From ced70f7b303dfef4cb62a9f2ac331e72bace6bb8 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 9 Sep 2015 14:54:06 +0200 Subject: [PATCH 0001/1168] [SPARK-10514] [Mesos] waiting for min no of total cores acquired by Spark by implementing the sufficientResourcesRegistered method --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 452c32d5411cd..36a22904bb644 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -222,6 +222,10 @@ private[spark] class CoarseMesosSchedulerBackend( markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} From bc06cbda4ea34c981bf56060396c8db08655bdcf Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Thu, 10 Sep 2015 13:37:32 +0200 Subject: [PATCH 0002/1168] [SPARK-10514] [Mesos] Updating the documents for spark.scheduler.minRegisteredResourcesRatio parameter --- docs/configuration.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index e287591f3fda1..ec08668dcad0d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1109,10 +1109,11 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.minRegisteredResourcesRatio - 0.8 for YARN mode; 0.0 otherwise + 0.8 for YARN mode; 0.0 for standalone mode and Mesos Coarse grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode) + (resources are executors in yarn mode, CPU cores in standalone mode and Mesos Coarsed Grained + mode ['spark.cores.max' value is total expected resources for Mesos Coarse grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config From 89b726b84d06bf54722303b894259c9de69bed65 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 15 Sep 2015 18:11:10 +0200 Subject: [PATCH 0003/1168] [SPARK-10618] [Mesos] Refactoring coarsed-grained scheduling condition and adding test for same --- .../mesos/CoarseMesosSchedulerBackend.scala | 20 ++++--- .../CoarseMesosSchedulerBackendSuite.scala | 60 +++++++++++++++++-- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 65cb5016cfcc9..c6c04f59f84f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -244,13 +244,7 @@ private[spark] class CoarseMesosSchedulerBackend( val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { + if (isOfferValidForScheduling(meetsConstraints, slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -292,6 +286,18 @@ private[spark] class CoarseMesosSchedulerBackend( } + def isOfferValidForScheduling(meetsConstraints: Boolean, + slaveId: String, mem: Double, + cpus: Int, sc: SparkContext): Boolean = { + taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId) + } + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5a..b68fcb7e4d13a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -58,7 +58,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + driver: SchedulerDriver, sc: SparkContext): CoarseMesosSchedulerBackend = { val securityManager = mock[SecurityManager] val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( @@ -84,11 +84,10 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") - - sc = new SparkContext(sparkConf) - } + } test("mesos supports killing and limiting executors") { + sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -97,7 +96,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite sparkConf.set("spark.driver.host", "driverHost") sparkConf.set("spark.driver.port", "1234") - val backend = createSchedulerBackend(taskScheduler, driver) + val backend = createSchedulerBackend(taskScheduler, driver, sc) val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 @@ -140,12 +139,13 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } test("mesos supports killing and relaunching tasks with executors") { + sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) - val backend = createSchedulerBackend(taskScheduler, driver) + val backend = createSchedulerBackend(taskScheduler, driver, sc) val minMem = backend.calculateTotalMemory(sc) + 1024 val minCpu = 4 @@ -184,4 +184,52 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } + + test("testing various isOfferValidForScheduling cases") { + val sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + .set("spark.cores.max", "10") + + val sc = new SparkContext(sparkConf) + + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + + val schedulerBackend = createSchedulerBackend(taskScheduler, driver, sc) + + // Return true when there is a valid offer + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + + schedulerBackend.slaveIdsWithExecutors += "Slave2" + schedulerBackend.failuresBySlaveId("Slave3") = 2 + schedulerBackend.totalCoresAcquired = 5 + + // Return false When offer do not meet constraints + assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + + // Return false When memory in offer is less than required memory + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + + // Return false When cpu in offer is less than required cpu + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + + // Return false When offer is from slave already running an executor + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + + // Return false When task is failed more than MAX_SLAVE_FAILURES times on the given slave + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + + schedulerBackend.totalCoresAcquired = 10 + + // Return false When max core is already acquired + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) + + } + + } From ce81b3a28a8e7a2ef46b23e6560b41d51289030c Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 15 Sep 2015 18:11:10 +0200 Subject: [PATCH 0004/1168] [SPARK-10618] [Mesos] Refactoring coarsed-grained scheduling condition and adding test for same --- .../mesos/CoarseMesosSchedulerBackend.scala | 18 ++++-- .../CoarseMesosSchedulerBackendSuite.scala | 60 +++++++++++++++++-- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 7d08eae0b4871..c0c8f1f6e3ec9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -254,12 +254,7 @@ private[spark] class CoarseMesosSchedulerBackend( val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue if (meetsConstraints) { - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { + if (isOfferValidForScheduling(meetsConstraints, slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -308,6 +303,17 @@ private[spark] class CoarseMesosSchedulerBackend( } } + def isOfferValidForScheduling(meetsConstraints: Boolean, + slaveId: String, mem: Double, + cpus: Int, sc: SparkContext): Boolean = { + taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId) + } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5a..b68fcb7e4d13a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -58,7 +58,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + driver: SchedulerDriver, sc: SparkContext): CoarseMesosSchedulerBackend = { val securityManager = mock[SecurityManager] val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( @@ -84,11 +84,10 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") - - sc = new SparkContext(sparkConf) - } + } test("mesos supports killing and limiting executors") { + sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -97,7 +96,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite sparkConf.set("spark.driver.host", "driverHost") sparkConf.set("spark.driver.port", "1234") - val backend = createSchedulerBackend(taskScheduler, driver) + val backend = createSchedulerBackend(taskScheduler, driver, sc) val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 @@ -140,12 +139,13 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } test("mesos supports killing and relaunching tasks with executors") { + sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) - val backend = createSchedulerBackend(taskScheduler, driver) + val backend = createSchedulerBackend(taskScheduler, driver, sc) val minMem = backend.calculateTotalMemory(sc) + 1024 val minCpu = 4 @@ -184,4 +184,52 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } + + test("testing various isOfferValidForScheduling cases") { + val sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + .set("spark.cores.max", "10") + + val sc = new SparkContext(sparkConf) + + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + + val schedulerBackend = createSchedulerBackend(taskScheduler, driver, sc) + + // Return true when there is a valid offer + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + + schedulerBackend.slaveIdsWithExecutors += "Slave2" + schedulerBackend.failuresBySlaveId("Slave3") = 2 + schedulerBackend.totalCoresAcquired = 5 + + // Return false When offer do not meet constraints + assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + + // Return false When memory in offer is less than required memory + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + + // Return false When cpu in offer is less than required cpu + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + + // Return false When offer is from slave already running an executor + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + + // Return false When task is failed more than MAX_SLAVE_FAILURES times on the given slave + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + + schedulerBackend.totalCoresAcquired = 10 + + // Return false When max core is already acquired + assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) + + } + + } From aa964b3d2b14a3efcc8061366ff8a62b05bcc7ae Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Sep 2015 11:10:55 +0200 Subject: [PATCH 0005/1168] [SPARK-10618] [Mesos] Killing space and removing duplication, also having one test for each case --- .../CoarseMesosSchedulerBackendSuite.scala | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index b68fcb7e4d13a..3b8713785ee05 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -77,6 +77,14 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite backend } + private def createSchedulerBackendForGivenSparkConf(sc : SparkContext) = { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + createSchedulerBackend(taskScheduler, driver, sc) + } + var sparkConf: SparkConf = _ before { @@ -84,10 +92,12 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") + .set("spark.cores.max", "10") + + sc = new SparkContext(sparkConf) } test("mesos supports killing and limiting executors") { - sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -139,7 +149,6 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } test("mesos supports killing and relaunching tasks with executors") { - sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -153,7 +162,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val offer1 = createOffer("o1", "s1", minMem, minCpu) mesosOffers.add(offer1) - val offer2 = createOffer("o2", "s1", minMem, 1); + val offer2 = createOffer("o2", "s1", minMem, 1) backend.resourceOffers(driver, mesosOffers) @@ -185,51 +194,50 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } - test("testing various isOfferValidForScheduling cases") { - val sparkConf = (new SparkConf) - .setMaster("local[*]") - .setAppName("test-mesos-dynamic-alloc") - .setSparkHome("/path") - .set("spark.cores.max", "10") - - val sc = new SparkContext(sparkConf) - - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) - - - val schedulerBackend = createSchedulerBackend(taskScheduler, driver, sc) + test("isOfferValidForScheduling return true when there is a valid offer") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return true when there is a valid offer assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + } - schedulerBackend.slaveIdsWithExecutors += "Slave2" - schedulerBackend.failuresBySlaveId("Slave3") = 2 - schedulerBackend.totalCoresAcquired = 5 + test("isOfferValidForScheduling return false when offer do not meet constraints") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When offer do not meet constraints assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When memory in offer is less than required memory assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When cpu in offer is less than required cpu assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + } + + test("isOfferValidForScheduling return false when offer is from slave already running" + + " an executor") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdsWithExecutors += "Slave2" - // Return false When offer is from slave already running an executor assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when task is failed more than " + + "MAX_SLAVE_FAILURES times on the given slave") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.failuresBySlaveId("Slave3") = 2 - // Return false When task is failed more than MAX_SLAVE_FAILURES times on the given slave assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + } + test("isOfferValidForScheduling return false when max core is already acquired") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - // Return false When max core is already acquired assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) - } - - } From d733bf27c113c409639cf83a11db937c03b2e74d Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Thu, 17 Sep 2015 18:28:59 +0200 Subject: [PATCH 0006/1168] [SPARK-10514] [CORE] [Mesos] Refactoring fine-grained scheduling condition and adding test for it. --- .../cluster/mesos/MesosSchedulerBackend.scala | 25 +++++--- .../mesos/MesosSchedulerBackendSuite.scala | 62 ++++++++++++++++++- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 281965a5981bb..2617a126889a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -246,14 +246,16 @@ private[spark] class MesosSchedulerBackend( val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check offers for - // 1. Memory requirements - // 2. CPU requirements - need at least 1 for executor, 1 for task - val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) - val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsRequirements = - (meetsMemoryRequirements && meetsCPURequirements) || - (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + isOfferValidForScheduling(cpus, mem, slaveId, sc) + + // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + s"$offerAttributes mem: $mem cpu: $cpus") @@ -328,6 +330,15 @@ private[spark] class MesosSchedulerBackend( } } + def isOfferValidForScheduling(cpusOffered: Double, memory : Double, + slaveId: String, sc : SparkContext): Boolean = { + val meetsMemoryRequirements = memory >= calculateTotalMemory(sc) + val meetsCPURequirements = cpusOffered >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + (meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdToExecutorInfo.contains(slaveId) && cpusOffered >= scheduler.CPUS_PER_TASK) + } + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ def createMesosTask( task: TaskDescription, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index c4dc560031207..11e8a330f9c60 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.Protos._ -import org.apache.mesos.SchedulerDriver +import org.apache.mesos.{Protos, SchedulerDriver} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} @@ -344,4 +344,64 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && r.getRole.equals("prod") }) } + + private def createSchedulerBackendForGivenSparkConf(sc : SparkContext) : MesosSchedulerBackend = { + val conf = new SparkConf + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + new MesosSchedulerBackend(taskScheduler, sc, "master") + } + + test("isOfferValidForScheduling return true when there offer meet cpu and memory requirement") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling( 5, 10000, "Slave1", sc)) + } + + test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling(5, 10, "Slave1", sc) === false) + } + + test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling(0, 10000, "Slave1", sc) === false) + } + + test("isOfferValidForScheduling return true when offer is from slave already running and" + + " cpu is less than minimum cpu per task an executor") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdToExecutorInfo("Slave2") = null + + assert(schedulerBackend.isOfferValidForScheduling(2, 10000, "Slave2", sc) === true) + } + + test("isOfferValidForScheduling return false when offer is from slave already running but" + + " cpu is less than minimum cpu per task an executor") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdToExecutorInfo("Slave2") = null + + assert(schedulerBackend.isOfferValidForScheduling(1, 10000, "Slave2", sc) === false) + } + } From 288c698fe9ad959a5ee55d9a3498857c422c5d77 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 00:23:30 +0100 Subject: [PATCH 0007/1168] [SPARK-10618] [Mesos] Adressing comments on PR. --- .../mesos/CoarseMesosSchedulerBackend.scala | 26 +++++++++++------ .../cluster/mesos/MesosSchedulerBackend.scala | 12 ++++---- .../CoarseMesosSchedulerBackendSuite.scala | 28 +++++++++---------- .../mesos/MesosSchedulerBackendSuite.scala | 20 ++++++------- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index c0c8f1f6e3ec9..3cf296b7cb021 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -254,7 +254,7 @@ private[spark] class CoarseMesosSchedulerBackend( val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue if (meetsConstraints) { - if (isOfferValidForScheduling(meetsConstraints, slaveId, mem, cpus, sc)) { + if (isOfferSatisfiesRequirements(meetsConstraints, slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -303,16 +303,24 @@ private[spark] class CoarseMesosSchedulerBackend( } } - def isOfferValidForScheduling(meetsConstraints: Boolean, + // ToDo: Abstract out each condition and log them. + def isOfferSatisfiesRequirements(meetsConstraints: Boolean, slaveId: String, mem: Double, - cpus: Int, sc: SparkContext): Boolean = { - taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && + cpusOffered: Int, sc: SparkContext): Boolean = { + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpusOffered >= 1 + val needMoreCores = totalCoresAcquired < maxCores + val healthySlave = failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES + val taskOnEachSlaveLessThanExecutorLimit = taskIdToSlaveId.size < executorLimit + val executorNotRunningOnSlave = !slaveIdsWithExecutors.contains(slaveId) + + executorNotRunningOnSlave && + taskOnEachSlaveLessThanExecutorLimit && + needMoreCores && meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId) + meetsMemoryRequirements && + meetsCPURequirements && + healthySlave } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 2617a126889a9..9b398749eb48e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -246,14 +246,11 @@ private[spark] class MesosSchedulerBackend( val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task + // check if Attribute constraints is satisfied val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val meetsRequirements = - isOfferValidForScheduling(cpus, mem, slaveId, sc) + isOfferSatisfiesRequirements(cpus, mem, slaveId, sc) // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" @@ -330,7 +327,10 @@ private[spark] class MesosSchedulerBackend( } } - def isOfferValidForScheduling(cpusOffered: Double, memory : Double, + // check if all constraints are satisfied + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task + def isOfferSatisfiesRequirements(cpusOffered: Double, memory : Double, slaveId: String, sc : SparkContext): Boolean = { val meetsMemoryRequirements = memory >= calculateTotalMemory(sc) val meetsCPURequirements = cpusOffered >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 3b8713785ee05..c9bcfce8d846c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -194,50 +194,50 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } - test("isOfferValidForScheduling return true when there is a valid offer") { + test("isOfferSatisfiesRequirements return true when there is a valid offer") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc)) } - test("isOfferValidForScheduling return false when offer do not meet constraints") { + test("isOfferSatisfiesRequirements return false when offer do not meet constraints") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(false, "Slave1", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 1, 5, sc) === false) } - test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 0, sc) === false) } - test("isOfferValidForScheduling return false when offer is from slave already running" + + test("isOfferSatisfiesRequirements return false when offer is from slave already running" + " an executor") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdsWithExecutors += "Slave2" - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave2", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when task is failed more than " + + test("isOfferSatisfiesRequirements return false when task is failed more than " + "MAX_SLAVE_FAILURES times on the given slave") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.failuresBySlaveId("Slave3") = 2 - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave3", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when max core is already acquired") { + test("isOfferSatisfiesRequirements return false when max core is already acquired") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc) === false) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 11e8a330f9c60..cbfd4bcae7bf2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -365,43 +365,43 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi new MesosSchedulerBackend(taskScheduler, sc, "master") } - test("isOfferValidForScheduling return true when there offer meet cpu and memory requirement") { + test("isOfferSatisfiesRequirements return true when there offer meet cpu and memory requirement") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling( 5, 10000, "Slave1", sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements( 5, 10000, "Slave1", sc)) } - test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(5, 10, "Slave1", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(5, 10, "Slave1", sc) === false) } - test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(0, 10000, "Slave1", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(0, 10000, "Slave1", sc) === false) } - test("isOfferValidForScheduling return true when offer is from slave already running and" + + test("isOfferSatisfiesRequirements return true when offer is from slave already running and" + " cpu is less than minimum cpu per task an executor") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdToExecutorInfo("Slave2") = null - assert(schedulerBackend.isOfferValidForScheduling(2, 10000, "Slave2", sc) === true) + assert(schedulerBackend.isOfferSatisfiesRequirements(2, 10000, "Slave2", sc) === true) } - test("isOfferValidForScheduling return false when offer is from slave already running but" + + test("isOfferSatisfiesRequirements return false when offer is from slave already running but" + " cpu is less than minimum cpu per task an executor") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdToExecutorInfo("Slave2") = null - assert(schedulerBackend.isOfferValidForScheduling(1, 10000, "Slave2", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(1, 10000, "Slave2", sc) === false) } } From 53e0600b1fce9b13d8747a6037d5f0f0e2f8024c Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 11:06:17 +0100 Subject: [PATCH 0008/1168] [SPARK-10618] [Mesos] removing meetconstraints check as it is already done --- .../mesos/CoarseMesosSchedulerBackend.scala | 8 +++----- .../CoarseMesosSchedulerBackendSuite.scala | 17 ++++++----------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3cf296b7cb021..287830e007b02 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -254,7 +254,7 @@ private[spark] class CoarseMesosSchedulerBackend( val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue if (meetsConstraints) { - if (isOfferSatisfiesRequirements(meetsConstraints, slaveId, mem, cpus, sc)) { + if (isOfferSatisfiesRequirements(slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -304,9 +304,8 @@ private[spark] class CoarseMesosSchedulerBackend( } // ToDo: Abstract out each condition and log them. - def isOfferSatisfiesRequirements(meetsConstraints: Boolean, - slaveId: String, mem: Double, - cpusOffered: Int, sc: SparkContext): Boolean = { + def isOfferSatisfiesRequirements(slaveId: String, mem: Double, cpusOffered: Int, + sc: SparkContext): Boolean = { val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpusOffered >= 1 val needMoreCores = totalCoresAcquired < maxCores @@ -317,7 +316,6 @@ private[spark] class CoarseMesosSchedulerBackend( executorNotRunningOnSlave && taskOnEachSlaveLessThanExecutorLimit && needMoreCores && - meetsConstraints && meetsMemoryRequirements && meetsCPURequirements && healthySlave diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index c9bcfce8d846c..f6517c9090415 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -197,25 +197,20 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite test("isOfferSatisfiesRequirements return true when there is a valid offer") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 5, sc)) } - test("isOfferSatisfiesRequirements return false when offer do not meet constraints") { - val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - - assert(schedulerBackend.isOfferSatisfiesRequirements(false, "Slave1", 10000, 5, sc) === false) - } test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 1, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 1, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 0, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 0, sc) === false) } test("isOfferSatisfiesRequirements return false when offer is from slave already running" + @@ -223,7 +218,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdsWithExecutors += "Slave2" - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave2", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave2", 10000, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when task is failed more than " + @@ -231,13 +226,13 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.failuresBySlaveId("Slave3") = 2 - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave3", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave3", 10000, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when max core is already acquired") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 5, sc) === false) } } From f1bb5b1033216ac935d98a41e5528412dec13ec1 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 14:37:13 +0100 Subject: [PATCH 0009/1168] [SPARK-10618] [Mesos] checking in missed Style check fix --- .../cluster/mesos/CoarseMesosSchedulerBackendSuite.scala | 3 ++- .../cluster/mesos/MesosSchedulerBackendSuite.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index f6517c9090415..9a4f9b7352fd4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -201,7 +201,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } - test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less" + + " than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 1, 5, sc) === false) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index cbfd4bcae7bf2..9e013a19c6b18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -365,14 +365,16 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi new MesosSchedulerBackend(taskScheduler, sc, "master") } - test("isOfferSatisfiesRequirements return true when there offer meet cpu and memory requirement") { + test("isOfferSatisfiesRequirements return true when there offer meet cpu and" + + " memory requirement") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) assert(schedulerBackend.isOfferSatisfiesRequirements( 5, 10000, "Slave1", sc)) } - test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less " + + "than required memory") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) From 5fbf71ad846e8511444dadbf31210578b789f3cd Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 15 Sep 2015 09:58:49 -0700 Subject: [PATCH 0010/1168] [SPARK-10491] [MLLIB] move RowMatrix.dspr to BLAS jira: https://issues.apache.org/jira/browse/SPARK-10491 We implemented dspr with sparse vector support in `RowMatrix`. This method is also used in WeightedLeastSquares and other places. It would be useful to move it to `linalg.BLAS`. Let me know if new UT needed. Author: Yuhao Yang Closes #8663 from hhbyyh/movedspr. --- .../spark/ml/optim/WeightedLeastSquares.scala | 4 +- .../org/apache/spark/mllib/linalg/BLAS.scala | 44 +++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 40 +---------------- .../apache/spark/mllib/linalg/BLASSuite.scala | 25 +++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index a99e2ac4c6913..0ff8931b0bab4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares( if (fitIntercept) { // shift centers // A^T A - aBar aBar^T - RowMatrix.dspr(-1.0, aBar, aaValues) + BLAS.spr(-1.0, aBar, aaValues) // A^T b - bBar aBar BLAS.axpy(-bBar, aBar, abBar) } @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares { bbSum += w * b * b BLAS.axpy(w, a, aSum) BLAS.axpy(w * b, a, abSum) - RowMatrix.dspr(w, a, aaSum.values) + BLAS.spr(w, a, aaSum) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9ee81eda8a8c0..df9f4ae145b88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging { _nativeBLAS } + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix in a [[DenseVector]](column major) + */ + def spr(alpha: Double, v: Vector, U: DenseVector): Unit = { + spr(alpha, v, U.values) + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + val n = v.size + v match { + case DenseVector(values) => + NativeBLAS.dspr("U", n, alpha, values, 1, U) + case SparseVector(size, indices, values) => + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 83779ac88989b..e55ef26858adb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ @@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") ( // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { - RowMatrix.dspr(1.0, v, U.data) + BLAS.spr(1.0, v, U.data) U }, combOp = (U1, U2) => U1 += U2) @@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") ( @Experimental object RowMatrix { - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param U the upper triangular part of the matrix packed in an array (column major) - */ - // TODO: SPARK-10491 - move this method to linalg.BLAS - private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { - // TODO: Find a better home (breeze?) for this method. - val n = v.size - v match { - case DenseVector(values) => - blas.dspr("U", n, alpha, values, 1, U) - case SparseVector(size, indices, values) => - val nnz = indices.length - var colStartIdx = 0 - var prevCol = 0 - var col = 0 - var j = 0 - var i = 0 - var av = 0.0 - while (j < nnz) { - col = indices(j) - // Skip empty columns. - colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 - col = indices(j) - av = alpha * values(j) - i = 0 - while (i <= j) { - U(colStartIdx + indices(i)) += av * values(i) - i += 1 - } - j += 1 - prevCol = col - } - } - } - /** * Fills a full square matrix from its upper triangular part. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 8db5c8424abe9..96e5ffef7a131 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite { } } + test("spr") { + // test dense vector + val alpha = 0.1 + val x = new DenseVector(Array(1.0, 2, 2.1, 4)) + val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6)) + + spr(alpha, x, U) + assert(U ~== expected absTol 1e-9) + + val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5)) + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + spr(alpha, x, matrix33) + } + } + + // test sparse vector + val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2)) + val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + spr(0.1, sv, U2) + val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4)) + assert(U2 ~== expectedSparse absTol 1e-15) + } + test("syr") { val dA = new DenseMatrix(4, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) From 170474cba94bfc12d2b0a50992cc0a2615b9adde Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 10:45:02 -0700 Subject: [PATCH 0011/1168] [SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py. This change does two things: - tag a few tests and adds the mechanism in the build to be able to disable those tags, both in maven and sbt, for both junit and scalatest suites. - add some logic to run-tests.py to disable some tags depending on what files have changed; that's used to disable expensive tests when a module hasn't explicitly been changed, to speed up testing for changes that don't directly affect those modules. Author: Marcelo Vanzin Closes #8437 from vanzin/test-tags. --- core/pom.xml | 10 ------- dev/run-tests.py | 19 ++++++++++++-- dev/sparktestsupport/modules.py | 24 ++++++++++++++++- external/flume/pom.xml | 10 ------- external/kafka/pom.xml | 10 ------- external/mqtt/pom.xml | 10 ------- external/twitter/pom.xml | 10 ------- external/zeromq/pom.xml | 10 ------- extras/java8-tests/pom.xml | 10 ------- extras/kinesis-asl/pom.xml | 5 ---- launcher/pom.xml | 5 ---- mllib/pom.xml | 10 ------- network/common/pom.xml | 10 ------- network/shuffle/pom.xml | 10 ------- pom.xml | 17 ++++++++++-- project/SparkBuild.scala | 13 ++++++++-- sql/core/pom.xml | 5 ---- .../execution/HiveCompatibilitySuite.scala | 2 ++ sql/hive/pom.xml | 5 ---- .../spark/sql/hive/ExtendedHiveTest.java | 26 +++++++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 2 ++ streaming/pom.xml | 10 ------- unsafe/pom.xml | 10 ------- .../spark/deploy/yarn/ExtendedYarnTest.java | 26 +++++++++++++++++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 1 + .../yarn/YarnShuffleIntegrationSuite.scala | 1 + 26 files changed, 124 insertions(+), 147 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java create mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index e31d90f608892..8a20181096223 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,16 +331,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index d8b22e1665e7b..1a816585187d9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,6 +118,14 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) +def determine_tags_to_exclude(changed_modules): + tags = [] + for m in modules.all_modules: + if m not in changed_modules: + tags += m.test_tags + return tags + + # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -369,6 +377,7 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -392,7 +401,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules): +def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -401,6 +410,10 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + + if excluded_tags: + test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] + if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -500,8 +513,10 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] + excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -541,7 +556,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules) + run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 346452f3174e4..65397f1f3e0bc 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - should_run_r_tests=False): + test_tags=(), should_run_r_tests=False): """ Define a new module. @@ -50,6 +50,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. + :param test_tags A set of tags that will be excluded when running unit tests if the module + is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -60,6 +62,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations + self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -85,6 +88,9 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", + ], + test_tags=[ + "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -398,6 +404,22 @@ def contains_file(self, filename): ) +yarn = Module( + name="yarn", + dependencies=[], + source_file_regexes=[ + "yarn/", + "network/yarn/", + ], + sbt_test_goals=[ + "yarn/test", + "network-yarn/test", + ], + test_tags=[ + "org.apache.spark.deploy.yarn.ExtendedYarnTest" + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 132062f94fb45..3154e36c21ef5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,16 +66,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 05abd9e2e6810..7d0d46dadc727 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,16 +86,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 05e6338a08b0a..913c47d33f488 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 244ad58ae9593..9137bf25ee8ae 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 171df8682c848..6fec4f0e8a0f9 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,16 +57,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 81794a8536318..dba3dda8a9562 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,16 +58,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 6dd8ff69c2943..760f183a2ef37 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,11 +74,6 @@ scalacheck_${scala.binary.version} test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index ed38e66aa2467..80696280a1d18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,11 +42,6 @@ log4j test - - junit - junit - test - org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 22c0c6008ba37..5dedacb38874e 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,16 +94,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 1cc054a8936c5..9c12cca0df609 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,16 +64,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7a66c968041ce..e4f4c57b683c8 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,16 +78,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/pom.xml b/pom.xml index 6535994641145..2927d3e107563 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 0.9.2 ${java.home} + @@ -1952,6 +1964,7 @@ __not_used__ + ${test.exclude.tags} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 901cfa538d23e..d80d300f1c3b2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,11 +567,20 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", + // Exclude tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.exclude.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.exclude.tags").map { tags => + Seq("--exclude-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 465aa3a3888c2..fa6732db183d8 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - junit - junit - test - org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab309e0a1d36b..ffc4c32794ca4 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ac67fe5f47be9..82cfeb2bb95d3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,11 +160,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java new file mode 100644 index 0000000000000..e2183183fb559 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index f0bb77092c0cf..888d1b7b45532 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -32,6 +33,7 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ +@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 5cc9001b0e9ab..1e6ee009ca6d5 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,21 +84,11 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 066abe92e51c0..4e8b9a84bb67f 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,16 +55,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java new file mode 100644 index 0000000000000..7a8f2fe979c1f --- /dev/null +++ b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b5a42fd6afd98..105c3090d489d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ +@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 8d9c9b3004eda..4700e2428df08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ +@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 1fa82fca3ce35b7e48fc07a2c1f8480d2bcbcec5 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Tue, 15 Sep 2015 12:23:20 -0700 Subject: [PATCH 0012/1168] [PYSPARK] [MLLIB] [DOCS] Replaced addversion with versionadded in mllib.random Missed this when reviewing `pyspark.mllib.random` for SPARK-10275. Author: noelsmith Closes #8773 from noel-smith/mllib-random-versionadded-fix. --- python/pyspark/mllib/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 9c733b1332bc0..6a3c643b66417 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -41,7 +41,7 @@ class RandomRDDs(object): Generator methods for creating RDDs comprised of i.i.d samples from some distribution. - .. addedversion:: 1.1.0 + .. versionadded:: 1.1.0 """ @staticmethod From f3014442291b7eb8fc8a56daca7554294f414a0c Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 15 Sep 2015 20:42:33 +0100 Subject: [PATCH 0013/1168] [DOCS] Small fixes to Spark on Yarn doc * a follow-up to 16b6d18613e150c7038c613992d80a7828413e66 as `--num-executors` flag is not suppported. * links + formatting Author: Jacek Laskowski Closes #8762 from jaceklaskowski/docs-spark-on-yarn. --- docs/running-on-yarn.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 5159ef9e3394e..d1244323edfff 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -18,16 +18,16 @@ Spark application's configuration (driver, executors, and the AM when running in There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. + To launch a Spark application in `yarn-cluster` mode: - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ --master yarn-cluster \ - --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ @@ -37,7 +37,7 @@ For example: The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. The following shows how you can run `spark-shell` in `yarn-client` mode: $ ./bin/spark-shell --master yarn-client @@ -54,8 +54,8 @@ In `yarn-cluster` mode, the driver runs on a different machine than the client, # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. +Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration From 9bd7952481ef52c2367d58545232a0e52dd0865c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 13:03:38 -0700 Subject: [PATCH 0014/1168] Revert "[SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py." This reverts commit 8abef21dac1a6538c4e4e0140323b83d804d602b. --- core/pom.xml | 10 +++++++ dev/run-tests.py | 19 ++------------ dev/sparktestsupport/modules.py | 24 +---------------- external/flume/pom.xml | 10 +++++++ external/kafka/pom.xml | 10 +++++++ external/mqtt/pom.xml | 10 +++++++ external/twitter/pom.xml | 10 +++++++ external/zeromq/pom.xml | 10 +++++++ extras/java8-tests/pom.xml | 10 +++++++ extras/kinesis-asl/pom.xml | 5 ++++ launcher/pom.xml | 5 ++++ mllib/pom.xml | 10 +++++++ network/common/pom.xml | 10 +++++++ network/shuffle/pom.xml | 10 +++++++ pom.xml | 17 ++---------- project/SparkBuild.scala | 13 ++-------- sql/core/pom.xml | 5 ++++ .../execution/HiveCompatibilitySuite.scala | 2 -- sql/hive/pom.xml | 5 ++++ .../spark/sql/hive/ExtendedHiveTest.java | 26 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 2 -- streaming/pom.xml | 10 +++++++ unsafe/pom.xml | 10 +++++++ .../spark/deploy/yarn/ExtendedYarnTest.java | 26 ------------------- .../spark/deploy/yarn/YarnClusterSuite.scala | 1 - .../yarn/YarnShuffleIntegrationSuite.scala | 1 - 26 files changed, 147 insertions(+), 124 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java delete mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index 8a20181096223..e31d90f608892 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,6 +331,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index 1a816585187d9..d8b22e1665e7b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,14 +118,6 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) -def determine_tags_to_exclude(changed_modules): - tags = [] - for m in modules.all_modules: - if m not in changed_modules: - tags += m.test_tags - return tags - - # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -377,7 +369,6 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] - profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -401,7 +392,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): +def run_scala_tests(build_tool, hadoop_version, test_modules): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -410,10 +401,6 @@ def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) - - if excluded_tags: - test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] - if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -513,10 +500,8 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) - excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] - excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -556,7 +541,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) + run_scala_tests(build_tool, hadoop_version, test_modules) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 65397f1f3e0bc..346452f3174e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + should_run_r_tests=False): """ Define a new module. @@ -50,8 +50,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. - :param test_tags A set of tags that will be excluded when running unit tests if the module - is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -62,7 +60,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations - self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -88,9 +85,6 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", - ], - test_tags=[ - "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -404,22 +398,6 @@ def contains_file(self, filename): ) -yarn = Module( - name="yarn", - dependencies=[], - source_file_regexes=[ - "yarn/", - "network/yarn/", - ], - sbt_test_goals=[ - "yarn/test", - "network-yarn/test", - ], - test_tags=[ - "org.apache.spark.deploy.yarn.ExtendedYarnTest" - ] -) - # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 3154e36c21ef5..132062f94fb45 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,6 +66,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 7d0d46dadc727..05abd9e2e6810 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,6 +86,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 913c47d33f488..05e6338a08b0a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 9137bf25ee8ae..244ad58ae9593 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 6fec4f0e8a0f9..171df8682c848 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,6 +57,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index dba3dda8a9562..81794a8536318 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,6 +58,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 760f183a2ef37..6dd8ff69c2943 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,6 +74,11 @@ scalacheck_${scala.binary.version} test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index 80696280a1d18..ed38e66aa2467 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,6 +42,11 @@ log4j test + + junit + junit + test + org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dedacb38874e..22c0c6008ba37 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,6 +94,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 9c12cca0df609..1cc054a8936c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,6 +64,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index e4f4c57b683c8..7a66c968041ce 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,6 +78,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/pom.xml b/pom.xml index 2927d3e107563..6535994641145 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 0.9.2 ${java.home} - @@ -1964,7 +1952,6 @@ __not_used__ - ${test.exclude.tags} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d80d300f1c3b2..901cfa538d23e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,20 +567,11 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", - // Exclude tags defined in a system property - testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, - sys.props.get("test.exclude.tags").map { tags => - tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq - }.getOrElse(Nil): _*), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, - sys.props.get("test.exclude.tags").map { tags => - Seq("--exclude-categories=" + tags) - }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fa6732db183d8..465aa3a3888c2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,11 @@ jackson-databind ${fasterxml.jackson.version} + + junit + junit + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ffc4c32794ca4..ab309e0a1d36b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,13 +24,11 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82cfeb2bb95d3..ac67fe5f47be9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,6 +160,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java deleted file mode 100644 index e2183183fb559..0000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 888d1b7b45532..f0bb77092c0cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -33,7 +32,6 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e6ee009ca6d5..5cc9001b0e9ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,11 +84,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.seleniumhq.selenium selenium-java test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 4e8b9a84bb67f..066abe92e51c0 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,6 +55,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java deleted file mode 100644 index 7a8f2fe979c1f..0000000000000 --- a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 105c3090d489d..b5a42fd6afd98 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 4700e2428df08..8d9c9b3004eda 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ -@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From fad31e9951626f62559fdd507a7053785bf3a880 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Sep 2015 13:33:32 -0700 Subject: [PATCH 0015/1168] [SPARK-10437] [SQL] Support aggregation expressions in Order By JIRA: https://issues.apache.org/jira/browse/SPARK-10437 If an expression in `SortOrder` is a resolved one, such as `count(1)`, the corresponding rule in `Analyzer` to make it work in order by will not be applied. Author: Liang-Chi Hsieh Closes #8599 from viirya/orderby-agg. --- .../sql/catalyst/analysis/Analyzer.scala | 14 +++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 591747b45c376..02f34cbf58ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -561,7 +561,7 @@ class Analyzer( } case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved && !sort.resolved => + if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { @@ -598,9 +598,15 @@ class Analyzer( } } - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + // Since we don't rely on sort.resolved as the stop condition for this rule, + // we need to check this and prevent applying this rule multiple times + if (sortOrder == evaluatedOrderings) { + sort + } else { + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 962b100b532c9..f9981356f364f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1562,6 +1562,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { From 3bb5a1393b6ea576a0c4f977899d13756fb688aa Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Sep 2015 13:36:52 -0700 Subject: [PATCH 0016/1168] [SPARK-10475] [SQL] improve column prunning for Project on Sort Sometimes we can't push down the whole `Project` though `Sort`, but we still have a chance to push down part of it. Author: Wenchen Fan Closes #8644 from cloud-fan/column-prune. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 11 +++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f4caec7451a2..648a65e7c0eb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // Push down project if possible when the child is sort - case p @ Project(projectList, s @ Sort(_, _, grandChild)) - if s.references.subsetOf(p.outputSet) => - s.copy(child = Project(projectList, grandChild)) + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) + } else { + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order. + val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } + } // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index dbebcb86809de..4a1e7ceaf394b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Project on Sort") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + + val query = input.orderBy('b.asc).select('a).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + + comparePlans(optimized, correctAnswer) + } + // todo: add more tests for column pruning } From daf4658b2db8febbb4d017036b821c43d5dfcecb Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 15 Sep 2015 15:46:47 -0700 Subject: [PATCH 0017/1168] [SPARK-7685] [ML] Apply weights to different samples in Logistic Regression In fraud detection dataset, almost all the samples are negative while only couple of them are positive. This type of high imbalanced data will bias the models toward negative resulting poor performance. In python-scikit, they provide a correction allowing users to Over-/undersample the samples of each class according to the given weights. In auto mode, selects weights inversely proportional to class frequencies in the training set. This can be done in a more efficient way by multiplying the weights into loss and gradient instead of doing actual over/undersampling in the training dataset which is very expensive. http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html On the other hand, some of the training data maybe more important like the training samples from tenure users while the training samples from new users maybe less important. We should be able to provide another "weight: Double" information in the LabeledPoint to weight them differently in the learning algorithm. Author: DB Tsai Author: DB Tsai Closes #7884 from dbtsai/SPARK-7685. --- .../classification/LogisticRegression.scala | 199 +++++++++++------- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 12 +- .../stat/MultivariateOnlineSummarizer.scala | 75 ++++--- .../LogisticRegressionSuite.scala | 102 ++++++++- .../MultivariateOnlineSummarizerSuite.scala | 27 +++ project/MimaExcludes.scala | 10 +- 7 files changed, 303 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a460262b87e43..bd96e8d000ff2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,12 +29,12 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.storage.StorageLevel /** @@ -42,7 +42,7 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization with HasThreshold { + with HasStandardization with HasWeightCol with HasThreshold { /** * Set threshold in binary classification, in range [0, 1]. @@ -146,6 +146,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[classification] case class Instance(label: Double, weight: Double, features: Vector) + /** * :: Experimental :: * Logistic regression. @@ -218,31 +229,42 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, labelSummarizer) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), - (label: Double, features: Vector)) => - (summarizer.add(features), labelSummarizer.add(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, - classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, - classSummarizer2: MultiClassSummarizer)) => - (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) - }) + val (summarizer, labelSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer), + c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp) + } val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -295,7 +317,7 @@ class LogisticRegression(override val uid: String) new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialWeightsWithIntercept = + val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if ($(fitIntercept)) { @@ -312,14 +334,14 @@ class LogisticRegression(override val uid: String) b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialWeightsWithIntercept.toArray(numFeatures) - = math.log(histogram(1).toDouble / histogram(0).toDouble) + initialCoefficientsWithIntercept.toArray(numFeatures) + = math.log(histogram(1) / histogram(0)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeightsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.toBreeze.toDenseVector) - val (weights, intercept, objectiveHistory) = { + val (coefficients, intercept, objectiveHistory) = { /* Note that in Logistic Regression, the objective history (loss + regularization) is log-likelihood which is invariance under feature standardization. As a result, @@ -339,28 +361,29 @@ class LogisticRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 while (i < numFeatures) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } if ($(fitIntercept)) { - (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) } else { - (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) } } if (handlePersistence) instances.unpersist() - val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) val logRegSummary = new BinaryLogisticRegressionTrainingSummary( model.transform(dataset), $(probabilityCol), @@ -501,22 +524,29 @@ class LogisticRegressionModel private[ml] ( * corresponding joint dataset. */ private[classification] class MultiClassSummarizer extends Serializable { - private val distinctMap = new mutable.HashMap[Int, Long] + // The first element of value in distinctMap is the actually number of instances, + // and the second element of value is sum of the weights. + private val distinctMap = new mutable.HashMap[Int, (Long, Double)] private var totalInvalidCnt: Long = 0L /** * Add a new label into this MultilabelSummarizer, and update the distinct map. * @param label The label for this data point. + * @param weight The weight of this instances. * @return This MultilabelSummarizer */ - def add(label: Double): this.type = { + def add(label: Double, weight: Double = 1.0): this.type = { + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + + if (weight == 0.0) return this + if (label - label.toInt != 0.0 || label < 0) { totalInvalidCnt += 1 this } else { - val counts: Long = distinctMap.getOrElse(label.toInt, 0L) - distinctMap.put(label.toInt, counts + 1) + val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0)) + distinctMap.put(label.toInt, (counts + 1L, weightSum + weight)) this } } @@ -537,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable { } smallMap.distinctMap.foreach { case (key, value) => - val counts = largeMap.distinctMap.getOrElse(key, 0L) - largeMap.distinctMap.put(key, counts + value) + val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0)) + largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2)) } largeMap.totalInvalidCnt += smallMap.totalInvalidCnt largeMap @@ -550,13 +580,13 @@ private[classification] class MultiClassSummarizer extends Serializable { /** @return The number of distinct labels in the input dataset. */ def numClasses: Int = distinctMap.keySet.max + 1 - /** @return The counts of each label in the input dataset. */ - def histogram: Array[Long] = { - val result = Array.ofDim[Long](numClasses) + /** @return The weightSum of each label in the input dataset. */ + def histogram: Array[Double] = { + val result = Array.ofDim[Double](numClasses) var i = 0 val len = result.length while (i < len) { - result(i) = distinctMap.getOrElse(i, 0L) + result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2 i += 1 } result @@ -565,6 +595,8 @@ private[classification] class MultiClassSummarizer extends Serializable { /** * Abstraction for multinomial Logistic Regression Training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. */ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { @@ -584,10 +616,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each sample. */ + /** Field in "predictions" which gives the the true label of each instance. */ def labelCol: String } @@ -597,8 +629,8 @@ sealed trait LogisticRegressionSummary extends Serializable { * Logistic regression training results. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample as a vector. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -617,8 +649,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * Binary Logistic regression results for a given model. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @@ -687,14 +719,14 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for samples in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in a online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. @@ -702,25 +734,25 @@ class BinaryLogisticRegressionSummary private[classification] ( * @param featuresMean The mean values of the features. */ private class LogisticAggregator( - weights: Vector, + coefficients: Vector, numClasses: Int, fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { - private var totalCnt: Long = 0L + private var weightSum = 0.0 private var lossSum = 0.0 - private val weightsArray = weights match { + private val coefficientsArray = coefficients match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") } - private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) /** * Add a new training data to this LogisticAggregator, and update the loss and gradient @@ -729,13 +761,17 @@ private class LogisticAggregator( * @param label The label for this data point. * @param data The features for one data point in dense/sparse vector format to be added * into this aggregator. + * @param weight The weight for over-/undersamples each of training instance. Default is one. * @return This LogisticAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + + def add(label: Double, data: Vector, weight: Double = 1.0): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${data.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val localWeightsArray = weightsArray + if (weight == 0.0) return this + + val localCoefficientsArray = coefficientsArray val localGradientSumArray = gradientSumArray numClasses match { @@ -745,13 +781,13 @@ private class LogisticAggregator( var sum = 0.0 data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localWeightsArray(index) * (value / featuresStd(index)) + sum += localCoefficientsArray(index) * (value / featuresStd(index)) } } - sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } + sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 } } - val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label) data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { @@ -765,15 +801,15 @@ private class LogisticAggregator( if (label > 0) { // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - lossSum += MLUtils.log1pExp(margin) + lossSum += weight * MLUtils.log1pExp(margin) } else { - lossSum += MLUtils.log1pExp(margin) - margin + lossSum += weight * (MLUtils.log1pExp(margin) - margin) } case _ => new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + "binary classification for now.") } - totalCnt += 1 + weightSum += weight this } @@ -789,8 +825,8 @@ private class LogisticAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { - totalCnt += other.totalCnt + if (other.weightSum != 0.0) { + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -805,13 +841,17 @@ private class LogisticAggregator( this } - def count: Long = totalCnt - - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -823,7 +863,7 @@ private class LogisticAggregator( * It's used in Breeze's convex optimization routines. */ private class LogisticCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], numClasses: Int, fitIntercept: Boolean, standardization: Boolean, @@ -831,22 +871,23 @@ private class LogisticCostFun( featuresMean: Array[Double], regParamL2: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length - val w = Vectors.fromBreeze(weights) + val w = Vectors.fromBreeze(coefficients) - val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, - featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val logisticAggregator = { + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance.label, instance.features, instance.weight) + val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) + + data.treeAggregate( + new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean) + )(seqOp, combOp) + } val totalGradientArray = logisticAggregator.gradient.toArray - // regVal is the sum of weight squares excluding intercept for L2 regularization. + // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e9e99ed1db40e..8049d51fee5ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + " probabilities. Note: Not all models output well-calibrated probability estimates!" + - " These probabilities should be treated as confidences, not precise probabilities.", + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), @@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen { "options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + - " before fitting the model.", Some("true")), + " before fitting the model", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + - " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 30092170863ad..aff47fc326c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities") setDefault(probabilityCol, "probability") @@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params { private[ml] trait HasStandardization extends Params { /** - * Param for whether to standardize the training features before fitting the model.. + * Param for whether to standardize the training features before fitting the model. * @group param */ - final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model") setDefault(standardization, true) @@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params { private[ml] trait HasElasticNetParam extends Params { /** - * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. * @group param */ - final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getElasticNetParam: Double = $(elasticNetParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 51b713e263e0c..201333c3690df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector * format in a online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. * - * A numerically stable algorithm is implemented to compute sample mean and variance: + * A numerically stable algorithm is implemented to compute the mean and variance of instances: * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * + * For weighted instances, the unbiased estimation of variance is defined by the reliability + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]. */ @Since("1.1.0") @DeveloperApi @@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 + private var weightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 private var nnz: Array[Double] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. */ @Since("1.1.0") - def add(sample: Vector): this.type = { + def add(sample: Vector): this.type = add(sample, 1.0) + + private[spark] def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") + if (weight == 0.0) return this + if (n == 0) { - require(sample.size > 0, s"Vector should have dimension larger than zero.") - n = sample.size + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size currMean = Array.ofDim[Double](n) currM2n = Array.ofDim[Double](n) @@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = Array.fill[Double](n)(Double.MaxValue) } - require(n == sample.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.size}.") + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") val localCurrMean = currMean val localCurrM2n = currM2n @@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localNnz = nnz val localCurrMax = currMax val localCurrMin = currMin - sample.foreachActive { (index, value) => + instance.foreachActive { (index, value) => if (value != 0.0) { if (localCurrMax(index) < value) { localCurrMax(index) = value @@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) - localCurrM2n(index) += (value - localCurrMean(index)) * diff - localCurrM2(index) += value * value - localCurrL1(index) += math.abs(value) + localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + localCurrM2(index) += weight * value * value + localCurrL1(index) += weight * math.abs(value) - localNnz(index) += 1.0 + localNnz(index) += weight } } + weightSum += weight + weightSquareSum += weight * weight totalCnt += 1 this } @@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.weightSum != 0.0 && other.weightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt + weightSum += other.weightSum + weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { val thisNnz = nnz(i) @@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S nnz(i) = totalNnz i += 1 } - } else if (totalCnt == 0 && other.totalCnt != 0) { + } else if (weightSum == 0.0 && other.weightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt + this.weightSum = other.weightSum + this.weightSquareSum = other.weightSquareSum this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -158,28 +174,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def mean: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / totalCnt) + realMean(i) = currMean(i) * (nnz(i) / weightSum) i += 1 } Vectors.dense(realMean) } /** - * Sample variance of each dimension. + * Unbiased estimate of sample variance of each dimension. * */ @Since("1.1.0") override def variance: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = totalCnt - 1.0 + val denominator = weightSum - (weightSquareSum / weightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. if (denominator > 0.0) { @@ -187,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * + (weightSum - nnz(i)) / weightSum) / denominator i += 1 } } @@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } @@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def max: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def min: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) @@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL2: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL1: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cce39f382f738..f5219f9f574be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.classification +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) - sqlContext.createDataFrame( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } @@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol === "prediction") assert(lr.getRawPredictionCol === "rawPrediction") assert(lr.getProbabilityCol === "probability") + assert(lr.getWeightCol === "") assert(lr.getFitIntercept) assert(lr.getStandardization) val model = lr.fit(dataset) @@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("MultiClassSummarizer") { val summarizer1 = (new MultiClassSummarizer) .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) - assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1)) assert(summarizer1.countInvalid === 0) assert(summarizer1.numClasses === 7) val summarizer2 = (new MultiClassSummarizer) .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) - assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1)) assert(summarizer2.countInvalid === 0) assert(summarizer2.numClasses === 6) val summarizer3 = (new MultiClassSummarizer) .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) - assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3)) assert(summarizer3.countInvalid === 3) assert(summarizer3.numClasses === 5) val summarizer4 = (new MultiClassSummarizer) .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) - assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.histogram === Array[Double](0, 1, 1, 1)) assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) - assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1)) assert(summarizerA.countInvalid === 0) assert(summarizerA.numClasses === 7) // large map merges small one val summarizerB = summarizer3.merge(summarizer4) assert(summarizerB.hashCode() === summarizer3.hashCode()) - assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3)) assert(summarizerB.countInvalid === 5) assert(summarizerB.numClasses === 5) } + test("MultiClassSummarizer with weighted samples") { + val summarizer1 = (new MultiClassSummarizer) + .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1) + assert(Vectors.dense(summarizer1.histogram) ~== + Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0) + assert(Vectors.dense(summarizer2.histogram) ~== + Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10) + assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer = summarizer1.merge(summarizer2) + assert(Vectors.dense(summarizer.histogram) ~== + Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10) + assert(summarizer.countInvalid === 0) + assert(summarizer.numClasses === 7) + } + test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) @@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) + val interceptTheory = math.log(histogram(1) / histogram(0)) val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) @@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .forall(x => x(0) >= x(1))) } + + test("binary logistic regression with weighted samples") { + val (dataset, weightedDataset) = { + val nPoints = 1000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + + // Let's over-sample the positive samples twice. + val data1 = testData.flatMap { case labeledPoint: LabeledPoint => + if (labeledPoint.label == 1.0) { + Iterator(labeledPoint, labeledPoint) + } else { + Iterator(labeledPoint) + } + } + + val rnd = new Random(8392) + val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) => + if (rnd.nextGaussian() > 0.0) { + if (label == 1.0) { + Iterator( + Instance(label, 1.2, features), + Instance(label, 0.8, features), + Instance(0.0, 0.0, features)) + } else { + Iterator( + Instance(label, 0.3, features), + Instance(1.0, 0.0, features), + Instance(label, 0.1, features), + Instance(label, 0.6, features)) + } + } else { + if (label == 1.0) { + Iterator(Instance(label, 2.0, features)) + } else { + Iterator(Instance(label, 1.0, features)) + } + } + } + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LogisticRegression).setFitIntercept(true) + .setRegParam(0.0).setStandardization(true) + val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") + .setRegParam(0.0).setStandardization(true) + val model1a0 = trainer1a.fit(dataset) + val model1a1 = trainer1a.fit(weightedDataset) + val model1b = trainer1b.fit(weightedDataset) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 07efde4f5e6dc..b6d41db69be0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { s0.merge(s1) assert(s0.mean(0) ~== 1.0 absTol 1e-14) } + + test("merging summarizer with weighted samples") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. + assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 87b141cd3b058..46026c1e90ea0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,15 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.count") + ) case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From f562d122a3f4fac0a880a04d2b47851a282acd44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 16:45:47 -0700 Subject: [PATCH 0018/1168] [SPARK-10548] [SPARK-10563] [SQL] Fix concurrent SQL executions *Note: this is for master branch only.* The fix for branch-1.5 is at #8721. The query execution ID is currently passed from a thread to its children, which is not the intended behavior. This led to `IllegalArgumentException: spark.sql.execution.id is already set` when running queries in parallel, e.g.: ``` (1 to 100).par.foreach { _ => sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } ``` The cause is `SparkContext`'s local properties are inherited by default. This patch adds a way to exclude keys we don't want to be inherited, and makes SQL go through that code path. Author: Andrew Or Closes #8710 from andrewor14/concurrent-sql-executions. --- .../scala/org/apache/spark/SparkContext.scala | 9 +- .../org/apache/spark/ThreadingSuite.scala | 65 +++++------ .../sql/execution/SQLExecutionSuite.scala | 101 ++++++++++++++++++ 3 files changed, 132 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dee6091ce3caf..a2f34eafa2c38 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -347,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index a96a4ce201c21..54c131cdae367 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,7 +147,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -178,7 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -207,58 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - var throwable: Option[Throwable] = None - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") - sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { try { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task - try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } - } catch { - case s: SparkException => // ignored so that we don't print noise in test logs - } + threadTestValue = sc.getLocalProperty("test") } catch { case t: Throwable => throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - throwable.foreach { t => throw t } - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. + */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 0000000000000..63639681ef80a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} From c102fc55f1965b626923cebee3808bd31fe7e2fa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 16:53:27 -0700 Subject: [PATCH 0019/1168] [SPARK-10612] [SQL] Add prepare to LocalNode. The idea is that we should separate the function call that does memory reservation (i.e. prepare) from the function call that consumes the input (e.g. open()), so all operators can be a chance to reserve memory before they are all consumed. Author: Reynold Xin Closes #8761 from rxin/SPARK-10612. --- .../org/apache/spark/sql/execution/local/LocalNode.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 9840080e16953..569cff565c092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -45,6 +45,14 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging def output: Seq[Attribute] + /** + * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume + * any input data. + * + * Implementations of this must also call the `prepare()` function of its children. + */ + def prepare(): Unit = children.foreach(_.prepare()) + /** * Initializes the iterator state. Must be called before calling `next()`. * From 419185d7d2d43d39c14b34f0010886a23cbfe9aa Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 15 Sep 2015 17:01:10 -0700 Subject: [PATCH 0020/1168] [SPARK-10575] [SPARK CORE] Wrapped RDD.takeSample with Scope Remove return statements in RDD.takeSample and wrap it withScope Author: vinodkc Author: vinodkc Author: Vinod K C Closes #8730 from vinodkc/fix_takesample_return. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7dd2bc5d7cd72..a56e542242d5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -469,50 +469,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** From 98bf753adb1e173586a198f571e9ac86f5719a43 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 15 Sep 2015 17:11:21 -0700 Subject: [PATCH 0021/1168] [SPARK-10381] Fix mixup of taskAttemptNumber & attemptId in OutputCommitCoordinator When speculative execution is enabled, consider a scenario where the authorized committer of a particular output partition fails during the OutputCommitter.commitTask() call. In this case, the OutputCommitCoordinator is supposed to release that committer's exclusive lock on committing once that task fails. However, due to a unit mismatch (we used task attempt number in one place and task attempt id in another) the lock will not be released, causing Spark to go into an infinite retry loop. This bug was masked by the fact that the OutputCommitCoordinator does not have enough end-to-end tests (the current tests use many mocks). Other factors contributing to this bug are the fact that we have many similarly-named identifiers that have different semantics but the same data types (e.g. attemptNumber and taskAttemptId, with inconsistent variable naming which makes them difficult to distinguish). This patch adds a regression test and fixes this bug by always using task attempt numbers throughout this code. Author: Josh Rosen Closes #8544 from JoshRosen/SPARK-10381. --- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../org/apache/spark/TaskEndReason.scala | 7 +- .../executor/CommitDeniedException.scala | 4 +- .../spark/mapred/SparkHadoopMapRedUtil.scala | 20 ++---- .../apache/spark/scheduler/DAGScheduler.scala | 7 +- .../scheduler/OutputCommitCoordinator.scala | 48 +++++++------ .../org/apache/spark/scheduler/TaskInfo.scala | 7 +- .../status/api/v1/AllStagesResource.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- ...putCommitCoordinatorIntegrationSuite.scala | 68 +++++++++++++++++++ .../OutputCommitCoordinatorSuite.scala | 24 ++++--- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- project/MimaExcludes.scala | 36 +++++++++- .../datasources/WriterContainer.scala | 3 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 17 files changed, 174 insertions(+), 69 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ae5926dd534a6..ac6eaab20d8d2 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -104,8 +104,7 @@ class SparkHadoopWriter(jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 2ae878b3e6087..7137246bc34f2 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -193,9 +193,12 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" /** * If a task failed because its attempt to commit was denied, do not count this failure * towards failing the stage. This is intended to prevent spurious stage failures in cases diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da1..7d84889a2def0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f405b732e4725..f7298e8d5c62c 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) @@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b4f90e8347894..3c9a66e504403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1128,8 +1128,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5d926377ce86b..add0dedc03f44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced77700..f113c2b1b8433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b3..24a0b5220695c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b71f55b7bb4f..712782d27b3cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -621,7 +621,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber val svgTag = if (totalExecutionTime == 0) { @@ -967,7 +967,7 @@ private[ui] class TaskDataSource( new TaskTableRowData( info.index, info.taskId, - info.attempt, + info.attemptNumber, info.speculative, info.status, info.taskLocality.toString, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 24f78744ad74c..99614a786bd93 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -266,7 +266,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + ("Attempt" -> taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 0000000000000..1ae5b030f0832 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. + */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e5ecd4b7c2610..6d08d7c5b7d2a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. + * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. */ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -164,27 +167,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 47e548ef0d442..143c1b901df11 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -499,7 +499,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46026c1e90ea0..1c96b0958586f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,7 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticCostFun.this"), @@ -53,6 +53,23 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.count") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.5") => Seq( @@ -213,6 +230,23 @@ object MimaExcludes { // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.4") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f8ef674ed29c1..cfd64c1d9eb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -198,8 +198,7 @@ private[sql] abstract class BaseWriterContainer( } def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) } def abortTask(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2bbb41ca777b7..7a46c69a056b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -54,9 +54,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { details = "" ) - private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( taskId = taskId, - attempt = attempt, + attemptNumber = attemptNumber, // The following fields are not used in tests index = 0, launchTime = 0, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 4ca8042d22367..c8d6b718045a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -121,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { From 476a259a222d0d3fca72fe26d1e4e6ada3b9bbfc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 17:24:32 -0700 Subject: [PATCH 0022/1168] [SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and reduces a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) where the output of `TakeOrderedAndProjectNode` is not actually ordered. Author: Andrew Or Closes #8764 from andrewor14/sql-local-tests-cleanup. --- .../spark/sql/execution/local/LocalNode.scala | 8 +- .../sql/execution/local/SampleNode.scala | 16 +- .../local/TakeOrderedAndProjectNode.scala | 2 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- .../spark/sql/execution/local/DummyNode.scala | 68 ++++ .../sql/execution/local/ExpandNodeSuite.scala | 54 ++- .../sql/execution/local/FilterNodeSuite.scala | 34 +- .../execution/local/HashJoinNodeSuite.scala | 141 ++++---- .../execution/local/IntersectNodeSuite.scala | 24 +- .../sql/execution/local/LimitNodeSuite.scala | 28 +- .../sql/execution/local/LocalNodeSuite.scala | 73 +--- .../sql/execution/local/LocalNodeTest.scala | 165 ++------- .../local/NestedLoopJoinNodeSuite.scala | 316 ++++++------------ .../execution/local/ProjectNodeSuite.scala | 39 ++- .../sql/execution/local/SampleNodeSuite.scala | 35 +- .../TakeOrderedAndProjectNodeSuite.scala | 50 ++- .../sql/execution/local/UnionNodeSuite.scala | 49 +-- 17 files changed, 468 insertions(+), 636 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 569cff565c092..f96b62a67a254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { protected val codegenEnabled: Boolean = conf.codegenEnabled protected val unsafeEnabled: Boolean = conf.unsafeEnabled - lazy val schema: StructType = StructType.fromAttributes(output) - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - def output: Seq[Attribute] - /** * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume * any input data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index abf3df1c0c2af..793700803f216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.local -import java.util.Random - import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + /** * Sample the dataset. * @@ -51,18 +50,15 @@ case class SampleNode( override def open(): Unit = { child.open() - val (sampler, _seed) = if (withReplacement) { - val random = new Random(seed) + val sampler = + if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, // requiring us to copy the row, which is more expensive than the random number generator. - (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), - // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result - // of DataFrame - random.nextLong()) + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) } else { - (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) } - sampler.setSeed(_seed) + sampler.setSeed(seed) iterator = sampler.sample(child.asIterator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index 53f1dcc65d8cf..ae672fbca8d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode( } // Close it eagerly since we don't need it. child.close() - iterator = queue.iterator + iterator = queue.toArray.sorted(ord).iterator } override def next(): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index de45ae4635fb7..3d218f01c9ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -238,7 +238,7 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 0000000000000..efc3227dd60d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index cfa7f3f6dcb97..bbd94d8da2d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.catalyst.dsl.expressions._ + + class ExpandNodeSuite extends LocalNodeTest { - import testImplicits._ - - test("expand") { - val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") - checkAnswer( - input, - node => - ExpandNode(conf, Seq( - Seq( - input.col("key") + input.col("value"), input.col("key") - input.col("value") - ).map(_.expr), - Seq( - input.col("key") * input.col("value"), input.col("key") / input.col("value") - ).map(_.expr) - ), node.output, node), - Seq( - (2, 0), - (1, 1), - (4, 0), - (4, 1), - (6, 0), - (9, 1), - (8, 0), - (16, 1), - (10, 0), - (25, 1) - ).toDF().collect() - ) + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index a12670e347c25..4eadce646d379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 +17,29 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(conf, condition.expr, node), - testData.filter(condition).collect() - ) +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(conf, condition.expr, node), - emptyTestData.filter(condition).collect() - ) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 78d891351f4a9..5c1bdb088eeed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,99 +18,80 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.execution.joins +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class HashJoinNodeSuite extends LocalNodeTest { - import testImplicits._ + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { - test(s"$suiteName: inner join with one match per row") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(upperCaseData.col("N").expr), - Seq(lowerCaseData.col("n").expr), - joins.BuildLeft, - node1, - node2) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N").collect() - ) + /** + * Test inner hash join with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) } + assert(actualOutput === expectedOutput) } - test(s"$suiteName: inner join with multiple matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 1).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - x.join(y).where($"x.a" === $"y.a").collect() - ) - } + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) } - test(s"$suiteName: inner join, no matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 2).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - Nil - ) - } + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) } - test(s"$suiteName: big inner join, 4 matches per row") { - withSQLConf(confPairs: _*) { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as("x") - val bigDataY = bigData.as("y") + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } - checkAnswer2( - bigDataX, - bigDataY, - wrapForUnsafe( - (node1, node2) => - HashJoinNode( - conf, - Seq(bigDataX.col("key").expr), - Seq(bigDataY.col("key").expr), - joins.BuildLeft, - node1, - node2) - ), - bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) - } + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) } } - joinSuite( - "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index 7deaa375fcfc2..c0ad2021b204a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.execution.local -class IntersectNodeSuite extends LocalNodeTest { - import testImplicits._ +class IntersectNodeSuite extends LocalNodeTest { test("basic") { - val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") - - checkAnswer2( - input1, - input2, - (node1, node2) => IntersectNode(conf, node1, node2), - input1.intersect(input2).collect() - ) + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 3b183902007e4..fb790636a3689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(conf, 10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(conf, 10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index b89fa46f8b3b4..0d1ed99eec6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -17,28 +17,24 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.IntegerType -class LocalNodeSuite extends SparkFunSuite { - private val data = (1 to 100).toArray +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) - data.foreach { i => + data.foreach { case (k, v) => assert(node.next()) // fetch should be idempotent val fetched = node.fetch() assert(node.fetch() === fetched) assert(node.fetch() === fetched) - assert(node.fetch().numFields === 1) - assert(node.fetch().getInt(0) === i) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) } assert(!node.next()) node.close() @@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() - data.foreach { i => + data.foreach { case (k, v) => // hasNext should be idempotent assert(iter.hasNext) assert(iter.hasNext) val item = iter.next() - assert(item.numFields === 1) - assert(item.getInt(0) === i) + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) } intercept[NoSuchElementException] { iter.next() @@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) - assert(collected.forall(_.size === 1)) - assert(collected.map(_.getInt(0)) === data) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) node.close() } } - -/** - * A dummy [[LocalNode]] that just returns one row per integer in the input. - */ -private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { - private var index = Int.MinValue - - def this(input: Array[Int]) { - this(new SQLConf, input) - } - - def isOpen: Boolean = { - index != Int.MinValue - } - - override def output: Seq[Attribute] = { - Seq(AttributeReference("something", IntegerType)()) - } - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - val values = Array(input(index).asInstanceOf[Any]) - new GenericInternalRow(values) - } - - override def close(): Unit = { - index = Int.MinValue - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 86dd28064cc6a..098050bcd2236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -17,147 +17,54 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} -class LocalNodeTest extends SparkFunSuite with SharedSQLContext { - def conf: SQLConf = sqlContext.conf +class LocalNodeTest extends SparkFunSuite { - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer( - input: DataFrame, - nodeFunction: LocalNode => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - input :: Nil, - nodes => nodeFunction(nodes.head), - expectedAnswer, - sortAnswers) - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer2( - left: DataFrame, - right: DataFrame, - nodeFunction: (LocalNode, LocalNode) => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - left :: right :: Nil, - nodes => nodeFunction(nodes(0), nodes(1)), - expectedAnswer, - sortAnswers) - } + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows */ - protected def doCheckAnswer( - input: Seq[DataFrame], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - LocalNodeTest.checkAnswer( - input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { - case Some(errorMessage) => fail(errorMessage) - case None => + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) } } - protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { - new SeqScanNode( - conf, - df.queryExecution.sparkPlan.output, - df.queryExecution.toRdd.map(_.copy()).collect()) - } - -} - -/** - * Helper methods for writing tests of individual local physical operators. - */ -object LocalNodeTest { - /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. */ - def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { - - val outputNode = nodeFunction(input) - - val outputResult: Seq[Row] = try { - outputNode.collect() - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing local plan: - | $outputNode - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => - s""" - | Results do not match for local plan: - | $outputNode - | $errorMessage - """.stripMargin + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index b1ef26ba82f16..40299d9d5ee37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -18,222 +18,128 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class NestedLoopJoinNodeSuite extends LocalNodeTest { - import testImplicits._ - - private def joinSuite( - suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { - test(s"$suiteName: left outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("n") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - upperCaseData.col("N") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) } } + } - test(s"$suiteName: right outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) } - test(s"$suiteName: full outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) - } + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") } } - joinSuite( - "general-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "general-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "tungsten-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") - joinSuite( - "tungsten-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 38e0a230c46d8..02ecb23d34b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(conf, columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(conf, columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index 87a7da453999c..a3e83bbd51457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.execution.local -class SampleNodeSuite extends LocalNodeTest { +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + - import testImplicits._ +class SampleNodeSuite extends LocalNodeTest { private def testSample(withReplacement: Boolean): Unit = { - test(s"withReplacement: $withReplacement") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), - input.sample(withReplacement, 0.3, seed).collect() - ) + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index ff28b24eeff14..42ebc7bfcaadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -17,38 +17,34 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} +import scala.util.Random -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder - import testImplicits._ - private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - sortOrder - } +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { - val testCaseName = if (desc) "desc" else "asc" - test(testCaseName) { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val sortColumn = if (desc) input.col("key").desc else input.col("key") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), - input.sort(sortColumn).limit(5).collect() - ) + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) } } - testTakeOrderedAndProjectNode(desc = false) - testTakeOrderedAndProjectNode(desc = true) + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index eedd7320900f9..666b0235c061d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(conf, nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } From ca53d14621f7f4e5dfe025ad9e2c2f658a69497e Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 15 Sep 2015 19:41:38 -0700 Subject: [PATCH 0023/1168] [SPARK-9078] [SQL] Allow jdbc dialects to override the query used to check the table. Current implementation uses query with a LIMIT clause to find if table already exists. This syntax works only in some database systems. This patch changes the default query to the one that is likely to work on most databases, and adds a new method to the JdbcDialect abstract class to allow dialects to override the default query. I looked at using the JDBC meta data calls, it turns out there is no common way to find the current schema, catalog..etc. There is a new method Connection.getSchema() , but that is available only starting jdk1.7 , and existing jdbc drivers may not have implemented it. Other option was to use jdbc escape syntax clause for LIMIT, not sure on how well this supported in all the databases also. After looking at all the jdbc metadata options my conclusion was most common way is to use the simple select query with 'where 1 =0' , and allow dialects to customize as needed Author: sureshthalamati Closes #8676 from sureshthalamati/table_exists_spark-9078. --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 9 ++++++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 14 +++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b2a66dd417b4c..745bb4ec9cf1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 26788b2a4fd69..f89d55b20e212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -42,10 +42,13 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, table: String): Boolean = { + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overriden by the dialects. + Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b98..68ebaaca6c53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ed710689cc670..5ab9381de4d66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } } From cce3997ef28a80cdcac8a26c570654dfd7f6c9ef Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 15 Sep 2015 19:43:26 -0700 Subject: [PATCH 0024/1168] [SPARK-10595] [ML] [MLLIB] [DOCS] Various ML guide cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various ML guide cleanups. * ml-guide.md: Make it easier to access the algorithm-specific guides. * LDA user guide: EM often begins with useless topics, but running longer generally improves them dramatically. E.g., 10 iterations on a Wikipedia dataset produces useless topics, but 50 iterations produces very meaningful topics. * mllib-feature-extraction.html#elementwiseproduct: “w” parameter should be “scalingVec” * Clean up Binarizer user guide a little. * Document in Pipeline that users should not put an instance into the Pipeline in more than 1 place. * spark.ml Word2Vec user guide: clean up grammar/writing * Chi Sq Feature Selector docs: Improve text in doc. CC: mengxr feynmanliang Author: Joseph K. Bradley Closes #8752 from jkbradley/mlguide-fixes-1.5. --- docs/ml-features.md | 34 +++++++++++++++++--- docs/ml-guide.md | 31 ++++++++++++------- docs/mllib-clustering.md | 4 +++ docs/mllib-feature-extraction.md | 53 +++++++++++++++++++++----------- docs/mllib-guide.md | 4 +-- 5 files changed, 91 insertions(+), 35 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a414c21b5c280..b70da4ac63845 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -123,12 +123,21 @@ for features_label in rescaledData.select("features", "label").take(3): ## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used for as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
+ +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + {% highlight scala %} import org.apache.spark.ml.feature.Word2Vec @@ -152,6 +161,10 @@ result.select("result").take(3).foreach(println)
+ +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + {% highlight java %} import java.util.Arrays; @@ -192,6 +205,10 @@ for (Row r: result.select("result").take(3)) {
+ +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + {% highlight python %} from pyspark.ml.feature import Word2Vec @@ -621,12 +638,15 @@ for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
+ +Refer to the [Binarizer API doc](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details. + {% highlight scala %} import org.apache.spark.ml.feature.Binarizer import org.apache.spark.sql.DataFrame @@ -650,6 +670,9 @@ binarizedFeatures.collect().foreach(println)
+ +Refer to the [Binarizer API doc](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details. + {% highlight java %} import java.util.Arrays; @@ -687,6 +710,9 @@ for (Row r : binarizedFeatures.collect()) {
+ +Refer to the [Binarizer API doc](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details. + {% highlight python %} from pyspark.ml.feature import Binarizer diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 78c93a95c7807..c5d7f990021f1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -32,7 +32,21 @@ See the [algorithm guides](#algorithm-guides) section below for guides on sub-pa * This will become a table of contents (this text will be scraped). {:toc} -# Main concepts +# Algorithm guides + +We provide several algorithm guides specific to the Pipelines API. +Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. +Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests +provide class probabilities, and linear models provide model summaries. + +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) + + +# Main concepts in Pipelines Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. @@ -166,6 +180,11 @@ compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. @@ -184,16 +203,6 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm guides - -There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision Trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) - # Code examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 3fb35d3c50b06..c2711cf82deb4 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -507,6 +507,10 @@ must also be $> 1.0$. Providing `Vector(-1)` results in default behavior $> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. * `maxIterations`: The maximum number of EM iterations. +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + `EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only the inferred topics but also the full training corpus and topic distributions for each document in the training corpus. A diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index de86aba2ae627..7e417ed5f37a9 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -380,35 +380,43 @@ data2 = labels.zip(normalizer2.transform(features))
-## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
-
+
+ +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -434,7 +442,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
-
+
+ +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -486,7 +498,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. +Qu8T948*1# +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -506,7 +523,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 257f7cc7603fa..91e50ccfecec4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -13,9 +13,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: -* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). -* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API +* [`spark.ml`](ml-guide.html) provides higher-level API built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. From 122cff0991a143a85bfd7a00ee987e14688ff4b7 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 15 Sep 2015 23:25:51 -0700 Subject: [PATCH 0025/1168] [SPARK-10516] [ MLLIB] Added values property in DenseVector Author: Vinod K C Closes #8682 from vinodkc/fix_SPARK-10516. --- python/pyspark/mllib/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 380f86e9b44f8..4829acb16ed8a 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -399,6 +399,10 @@ def squared_distance(self, other): def toArray(self): return self.array + @property + def values(self): + return self.array + def __getitem__(self, item): return self.array[item] From a1132b00ce1c4ae36a3f6de1878b163b59341f84 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Wed, 16 Sep 2015 10:47:30 +0100 Subject: [PATCH 0026/1168] [SPARK-10511] [BUILD] Reset git repository before packaging source distro The calculation of Spark version is downloading Scala and Zinc in the build directory which is inflating the size of the source distribution. Reseting the repo before packaging the source distribution fix this issue. Author: Luciano Resende Closes #8774 from lresende/spark-10511. --- dev/create-release/release-build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index d0b3a54dde1dc..9dac43ce54425 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -99,6 +99,7 @@ fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" USER_HOST="$ASF_USERNAME@people.apache.org" +git clean -d -f -x rm .gitignore rm -rf .git cd .. From b04c20ecafe6600782557817fdc4dbcab86970c5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Sep 2015 04:34:14 -0700 Subject: [PATCH 0027/1168] [SPARK-10276] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.recommendation Author: Yu ISHIKAWA Closes #8677 from yu-iskw/SPARK-10276. --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 506ca2151cce7..95047b5b7b4b7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,7 +18,7 @@ import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) + + .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -111,13 +113,17 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, user, product): """ Predicts rating for the given user and product. """ return self._java_model.predict(int(user), int(product)) + @since("0.9.0") def predictAll(self, user_product): """ Returns a list of predicted ratings for input user and product pairs. @@ -128,6 +134,7 @@ def predictAll(self, user_product): user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) + @since("1.2.0") def userFeatures(self): """ Returns a paired RDD, where the first element is the user and the @@ -135,6 +142,7 @@ def userFeatures(self): """ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.2.0") def productFeatures(self): """ Returns a paired RDD, where the first element is the product and the @@ -142,6 +150,7 @@ def productFeatures(self): """ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.4.0") def recommendUsers(self, product, num): """ Recommends the top "num" number of users for a given product and returns a list @@ -149,6 +158,7 @@ def recommendUsers(self, product, num): """ return list(self.call("recommendUsers", product, num)) + @since("1.4.0") def recommendProducts(self, user, num): """ Recommends the top "num" number of products for a given user and returns a list @@ -157,17 +167,25 @@ def recommendProducts(self, user, num): return list(self.call("recommendProducts", user, num)) @property + @since("1.4.0") def rank(self): + """Rank for the features in this model""" return self.call("rank") @classmethod + @since("1.3.1") def load(cls, sc, path): + """Load a model from the given path""" model = cls._load_java(sc, path) wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) class ALS(object): + """Alternating Least Squares matrix factorization + + .. versionadded:: 0.9.0 + """ @classmethod def _prepare(cls, ratings): @@ -188,15 +206,31 @@ def _prepare(cls, ratings): return ratings @classmethod + @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of ratings given by users to some products, + in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + product of two lower-rank matrices of a given rank (number of features). To solve for these + features, we run a given number of iterations of ALS. This is done using a level of + parallelism given by `blocks`. + """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod + @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of 'implicit preferences' given by users + to some products, in the form of (userID, productID, preference) pairs. We approximate the + ratings matrix as the product of two lower-rank matrices of a given rank (number of + features). To solve for these features, we run a given number of iterations of ALS. + This is done using a level of parallelism given by `blocks`. + """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From 4b4ea296e195a2422ad54965028e8d2c5601f3bb Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 16 Sep 2015 19:19:23 +0100 Subject: [PATCH 0028/1168] [SPARK-10589] [WEBUI] Add defense against external site framing Set `X-Frame-Options: SAMEORIGIN` to protect against frame-related vulnerability Author: Sean Owen Closes #8745 from srowen/SPARK-10589. --- .../spark/deploy/worker/ui/WorkerWebUI.scala | 7 ++++--- .../org/apache/spark/metrics/MetricsSystem.scala | 2 +- .../spark/metrics/sink/MetricsServlet.scala | 6 +++--- .../scala/org/apache/spark/ui/JettyUtils.scala | 16 ++++++++++++++-- .../main/scala/org/apache/spark/ui/WebUI.scala | 4 ++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598c..1a0598e50dcf1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3b..48afe3ae3511f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a33074..4193e1d21d3c1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 779c0ba083596..b796a44fe01ac 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 61449847add3d..81a121fd441bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -76,9 +76,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath) + (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) From e4bafe29aff00057cfd049b2ac6a069b2456af4a Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 16 Sep 2015 13:20:39 -0700 Subject: [PATCH 0029/1168] [SPARK-10050] [SPARKR] Support collecting data of MapType in DataFrame. 1. Support collecting data of MapType from DataFrame. 2. Support data of MapType in createDataFrame. Author: Sun Rui Closes #8711 from sun-rui/SPARK-10050. --- R/pkg/R/SQLContext.R | 5 +- R/pkg/R/deserialize.R | 14 +++++ R/pkg/R/schema.R | 34 ++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 56 +++++++++++++++---- .../scala/org/apache/spark/api/r/SerDe.scala | 31 ++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 6 ++ 6 files changed, 123 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4ac057d0f2d83..1c58fd96d750a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,10 +41,7 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d1858ec227b56..ce88d0b071b72 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,7 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -121,6 +122,19 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 62d4f73878d29..8df1563f8ebc0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -131,13 +131,33 @@ checkType <- function(type) { if (type %in% primtiveTypes) { return() } else { - m <- regexec("^array<(.*)>$", type) - matchedStrings <- regmatches(type, m) - if (length(matchedStrings[[1]]) >= 2) { - elemType <- matchedStrings[[1]][2] - checkType(elemType) - return() - } + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) } stop(paste("Unsupported type for Dataframe:", type)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 98d4402d368e1..e159a69584274 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,7 +57,7 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) -test_that("infer types", { +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") @@ -72,9 +72,9 @@ test_that("infer types", { checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", { # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) - # ArrayType only for now - l <- list(as.list(1:10), list("a", "b")) - df <- createDataFrame(sqlContext, list(l), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b")) + expect_equal(names(ldf), c("a", "b", "c")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) }) +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("Collect DataFrame with complex types", { - # only ArrayType now - # TODO: tests for StructType and MapType after they are supported + # ArrayType df <- jsonFile(sqlContext, complexTypeJsonPath) ldf <- collect(df) @@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", { expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c92bb7a1c73c..0c78613e406e1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -209,11 +209,23 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + def writeObject(dos: DataOutputStream, obj: Object): Unit = { if (obj == null) { writeType(dos, "void") @@ -306,6 +318,25 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case _ => writeType(dos, "jobj") writeJObj(dos, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d4b834adb6e39..f45d119c8cfdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -64,6 +64,12 @@ private[r] object SQLUtils { case r"\Aarray<(.*)${elemType}>\Z" => { org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) } + case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } From 5999f1ad8f413cfdc76f5cf2a94f074b746a2aa6 Mon Sep 17 00:00:00 2001 From: Kevin Cox Date: Wed, 16 Sep 2015 15:30:17 -0700 Subject: [PATCH 0030/1168] [SPARK-9794] [SQL] Fix datetime parsing in SparkSQL. This fixes https://issues.apache.org/jira/browse/SPARK-9794 by using a real ISO8601 parser. (courtesy of the xml component of the standard java library) cc: angelini Author: Kevin Cox Closes #8396 from kevincox/kevincox-sql-time-parsing. --- .../sql/catalyst/util/DateTimeUtils.scala | 27 ++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 687ca000d12bb..400c4327be1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} +import javax.xml.bind.DatatypeConverter; import org.apache.spark.unsafe.types.UTF8String @@ -109,30 +110,22 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { + var indexOfGMT = s.indexOf("GMT"); + if (indexOfGMT != -1) { + // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) + val s0 = s.substring(0, indexOfGMT) + val s1 = s.substring(indexOfGMT + 3) + // Mapped to 2000-01-01T00:00+01:00 + stringToTime(s0 + s1) + } else if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { Timestamp.valueOf(s) } else { Date.valueOf(s) } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) + DatatypeConverter.parseDateTime(s).getTime() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 6b9a11f0ff743..46335941b62d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -136,6 +136,38 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } + test("string to time") { + // Tests with UTC. + var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-00:00") === c.getTime()) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30T10:00:00Z") === c.getTime()) + + // Tests with set time zone. + c.setTimeZone(TimeZone.getTimeZone("GMT-04:00")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00-04:00") === c.getTime()) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-04:00") === c.getTime()) + + // Tests with local time zone. + c.setTimeZone(TimeZone.getDefault()) + c.set(Calendar.MILLISECOND, 0) + + c.set(2000, 11, 30, 0, 0, 0) + assert(stringToTime("2000-12-30") === new Date(c.getTimeInMillis())) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30 10:00:00") === new Timestamp(c.getTimeInMillis())) + } + test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) From 1c5ce5917804c77e88eda2eaf6122ba09afe8b14 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Sep 2015 15:32:01 -0700 Subject: [PATCH 0031/1168] Tiny style fix for d39f15ea2b8bed5342d2f8e3c1936f915c470783. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 400c4327be1c7..781ed1688a327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} -import javax.xml.bind.DatatypeConverter; +import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String From edaf6c4cd37d9ef67e2aaf86b3244c4436762db4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 17 Sep 2015 00:48:57 -0700 Subject: [PATCH 0032/1168] [MINOR] [CORE] Fixes minor variable name typo Author: Cheng Lian Closes #8784 from liancheng/typo-fix. --- .../apache/spark/serializer/GenericAvroSerializerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69d..87f25e7245e1f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } From 0c6a6873211fe3e6457a1704bc0de4ae6f1e0485 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:43:59 -0700 Subject: [PATCH 0033/1168] [SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning Author: Yu ISHIKAWA Closes #8694 from yu-iskw/SPARK-10284. --- python/pyspark/ml/tuning.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index cae778869e9c5..ab5621f45c72c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,7 @@ import itertools import numpy as np +from pyspark import since from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model from pyspark.ml.util import keyword_only @@ -47,11 +48,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +64,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +78,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -104,6 +110,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -142,6 +150,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): @@ -150,6 +159,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +167,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +182,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +197,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +212,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -231,7 +248,15 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ if extra is None: extra = dict() newCV = Params.copy(self, extra) @@ -246,6 +271,8 @@ def copy(self, extra=None): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -256,6 +283,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid From ebb754c19204bf67ff84782e92ef882eee9e71af Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:45:20 -0700 Subject: [PATCH 0034/1168] [SPARK-10283] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.regression Author: Yu ISHIKAWA Closes #8693 from yu-iskw/SPARK-10283. --- python/pyspark/ml/regression.py | 65 +++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a9503608b7f25..21d454f9003bb 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -62,6 +63,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ @keyword_only @@ -81,6 +84,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True): @@ -96,13 +100,31 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("1.4.0") + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + @since("1.4.0") + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -110,6 +132,7 @@ def weights(self): return self._call_java("weights") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -162,6 +185,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -193,6 +218,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -209,6 +235,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -216,6 +243,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -225,13 +253,19 @@ def getImpurity(self): @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -242,8 +276,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -256,6 +295,8 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. versionadded:: 1.4.0 """ @@ -282,6 +323,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -336,6 +379,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, @@ -353,6 +397,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -360,12 +405,14 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -373,12 +420,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. @@ -386,12 +435,14 @@ def setNumTrees(self, value): self._paramMap[self.numTrees] = value return self + @since("1.4.0") def getNumTrees(self): """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) + @since("1.4.0") def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. @@ -399,6 +450,7 @@ def setFeatureSubsetStrategy(self, value): self._paramMap[self.featureSubsetStrategy] = value return self + @since("1.4.0") def getFeatureSubsetStrategy(self): """ Gets the value of featureSubsetStrategy or its default value. @@ -409,6 +461,8 @@ def getFeatureSubsetStrategy(self): class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @@ -435,6 +489,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -481,6 +537,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -498,6 +555,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -505,12 +563,14 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -518,12 +578,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. @@ -531,6 +593,7 @@ def setStepSize(self, value): self._paramMap[self.stepSize] = value return self + @since("1.4.0") def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -541,6 +604,8 @@ def getStepSize(self): class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 """ From e9bae83e2a1270af7447400773c561603e3bb3ce Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:47:21 -0700 Subject: [PATCH 0035/1168] [SPARK-10281] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.clustering Author: Yu ISHIKAWA Closes #8691 from yu-iskw/SPARK-10281. --- python/pyspark/ml/clustering.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cb4c16e25a7a3..7bb8ab94e17df 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -26,8 +27,11 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @@ -55,6 +59,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -88,6 +94,7 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ @@ -99,6 +106,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -110,12 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. @@ -130,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -147,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` From f8b860a051a16b397b41f9abb17bf7db8a19d469 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:48:45 -0700 Subject: [PATCH 0036/1168] [SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree Author: Yu ISHIKAWA Closes #8685 from yu-iskw/SPARK-10278. --- python/pyspark/mllib/tree.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d9..0001b60093a69 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. """ return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. @@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): From 647f901a6ef194eca88b16bf1377a2fb0277ef65 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:00 -0700 Subject: [PATCH 0037/1168] [SPARK-10279] [MLLIB] [PYSPARK] [DOCS] Add @since annotation to pyspark.mllib.util Author: Yu ISHIKAWA Closes #8689 from yu-iskw/SPARK-10279. --- python/pyspark/mllib/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 10a1e4b3eb0fc..39bc6586dd582 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -311,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ From 7bb7bee5bf016dd809abeb393ecc9ff020546d66 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:46 -0700 Subject: [PATCH 0038/1168] [SPARK-10274] [MLLIB] Add @since annotation to pyspark.mllib.fpm Author: Yu ISHIKAWA Closes #8665 from yu-iskw/SPARK-10274. --- python/pyspark/mllib/fpm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b18..bdabba9602a8c 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,7 +19,7 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,8 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 """ From 7f7a0767d522f42f70d1c5f5835579774191b744 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:51:19 -0700 Subject: [PATCH 0039/1168] [SPARK-10282] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.recommendation Author: Yu ISHIKAWA Closes #8692 from yu-iskw/SPARK-10282. --- python/pyspark/ml/recommendation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee6..ec5748a1cfe94 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -80,6 +81,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=3.19...) >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. @@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. @@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and From 285924705f801ef8e22dd62ce53c831150b18e7f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 17 Sep 2015 09:17:43 -0700 Subject: [PATCH 0040/1168] [SPARK-10077] [DOCS] [ML] Add package info for java of ml/feature Should be the same as SPARK-7808 but use Java for the code example. It would be great to add package doc for `spark.ml.feature`. Author: Holden Karau Closes #8740 from holdenk/SPARK-10077-JAVA-PACKAGE-DOC-FOR-SPARK.ML.FEATURE. --- .../apache/spark/ml/feature/package-info.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 100644 index 0000000000000..c22d2e0cd2d90 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + *
+ * 
+ *   import java.util.Arrays;
+ *
+ *   import org.apache.spark.api.java.JavaRDD;
+ *   import static org.apache.spark.sql.types.DataTypes.*;
+ *   import org.apache.spark.sql.types.StructType;
+ *   import org.apache.spark.sql.DataFrame;
+ *   import org.apache.spark.sql.RowFactory;
+ *   import org.apache.spark.sql.Row;
+ *
+ *   import org.apache.spark.ml.feature.*;
+ *   import org.apache.spark.ml.Pipeline;
+ *   import org.apache.spark.ml.PipelineStage;
+ *   import org.apache.spark.ml.PipelineModel;
+ *
+ *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
+ *  StructType schema = createStructType(
+ *    Arrays.asList(
+ *      createStructField("id", IntegerType, false),
+ *      createStructField("text", StringType, false),
+ *      createStructField("rating", DoubleType, false)));
+ *  JavaRDD rowRDD = jsc.parallelize(
+ *    Arrays.asList(
+ *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
+ *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
+ *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
+ *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
+ *  // define feature transformers
+ *  RegexTokenizer tok = new RegexTokenizer()
+ *    .setInputCol("text")
+ *    .setOutputCol("words");
+ *  StopWordsRemover sw = new StopWordsRemover()
+ *    .setInputCol("words")
+ *    .setOutputCol("filtered_words");
+ *  HashingTF tf = new HashingTF()
+ *    .setInputCol("filtered_words")
+ *    .setOutputCol("tf")
+ *    .setNumFeatures(10000);
+ *  IDF idf = new IDF()
+ *    .setInputCol("tf")
+ *    .setOutputCol("tf_idf");
+ *  VectorAssembler assembler = new VectorAssembler()
+ *    .setInputCols(new String[] {"tf_idf", "rating"})
+ *    .setOutputCol("features");
+ *
+ *  // assemble and fit the feature transformation pipeline
+ *  Pipeline pipeline = new Pipeline()
+ *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
+ *  PipelineModel model = pipeline.fit(df);
+ *
+ *  // save transformed features with raw data
+ *  model.transform(df)
+ *    .select("id", "text", "rating", "features")
+ *    .write().format("parquet").save("/output/path");
+ * 
+ * 
+ * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; From 6710e031dc5298392fef7fd191864fff5f637183 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 09:21:21 -0700 Subject: [PATCH 0041/1168] [SPARK-10459] [SQL] Do not need to have ConvertToSafe for PythonUDF JIRA: https://issues.apache.org/jira/browse/SPARK-10459 As mentioned in the JIRA, `PythonUDF` actually could process `UnsafeRow`. Specially, the rows in `childResults` in `BatchPythonEvaluation` will be projected to a `MutableRow`. So I think we can enable `canProcessUnsafeRows` for `BatchPythonEvaluation` and get rid of redundant `ConvertToSafe`. Author: Liang-Chi Hsieh Closes #8616 from viirya/pyudf-unsafe. --- .../scala/org/apache/spark/sql/execution/pythonUDFs.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80b..d0411da6fdf5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -337,6 +337,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) From cb75be57095e2e83afdfd9f4e336da63d50ef7a4 Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Thu, 17 Sep 2015 09:52:40 -0700 Subject: [PATCH 0042/1168] [SPARK-10660] Doc describe error in the "Running Spark on YARN" page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the Configuration section, the **spark.yarn.driver.memoryOverhead** and **spark.yarn.am.memoryOverhead**‘s default value should be "driverMemory * 0.10, with minimum of 384" and "AM memory * 0.10, with minimum of 384" respectively. Because from Spark 1.4.0, the **MEMORY_OVERHEAD_FACTOR** is set to 0.1.0, not 0.07. Author: yangping.wu Closes #8797 from 397090770/SparkOnYarnDocError. --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d1244323edfff..3a961d245f3de 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,14 +211,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.driver.memoryOverhead - driverMemory * 0.07, with minimum of 384 + driverMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead - AM memory * 0.07, with minimum of 384 + AM memory * 0.10, with minimum of 384 Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. From 85aa98ad055afe7a06902f45f82a5a8cd48f5794 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 10:02:15 -0700 Subject: [PATCH 0043/1168] [SPARK-10642] [PYSPARK] Fix crash when calling rdd.lookup() on tuple keys JIRA: https://issues.apache.org/jira/browse/SPARK-10642 When calling `rdd.lookup()` on a RDD with tuple keys, `portable_hash` will return a long. That causes `DAGScheduler.submitJob` to throw `java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Integer`. Author: Liang-Chi Hsieh Closes #8796 from viirya/fix-pyrdd-lookup. --- python/pyspark/rdd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9ef60a7e2c84b..ab5aab1e115f7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -2192,6 +2192,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() From a5082616941dcb522f247250d83c6dd2d95d1cf1 Mon Sep 17 00:00:00 2001 From: Josiah Samuel Date: Thu, 17 Sep 2015 10:18:21 -0700 Subject: [PATCH 0044/1168] [SPARK-10172] [CORE] disable sort in HistoryServer webUI This pull request is to address the JIRA SPARK-10172 (History Server web UI gets messed up when sorting on any column). The content of the table gets messed up due to the rowspan attribute of the table data(cell) during sorting. The current table sort library used in SparkUI (sorttable.js) doesn't support/handle cells(td) with rowspans. The fix will disable the table sort in the web UI, when there are jobs listed with multiple attempts. Author: Josiah Samuel Closes #8506 from josiahsams/SPARK-10172. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba1245..b347cb3be69f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } From e114044dc781d072824e7927b006c5270fbd45f3 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 17 Sep 2015 10:25:18 -0700 Subject: [PATCH 0045/1168] [SPARK-10531] [CORE] AppId is set as AppName in status rest api Verify it manually. Author: Jeff Zhang Closes #8688 from zjffdu/SPARK-10531. --- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../spark/deploy/history/FsHistoryProvider.scala | 9 ++++----- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 11 ++++++----- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a2f34eafa2c38..9c3218719f7fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -521,6 +521,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a5755eac36396..8eb2ba1e8683b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -146,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9bec..d518e92133aad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9a..99085ada9f0af 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). */ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 22e30ecaf0533..18eec7da9763e 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -658,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } From f8aeb6aa9d8028d3167525ea576f452093865ce0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 17 Sep 2015 11:05:30 -0700 Subject: [PATCH 0046/1168] [SPARK-10650] Clean before building docs The [published docs for 1.5.0](http://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/streaming/) have a bunch of test classes in them. The only way I can reproduce this is to `test:compile` before running `unidoc`. To prevent this from happening again, I've added a clean before doc generation. Author: Michael Armbrust Closes #8787 from marmbrus/testsInDocs. --- docs/_plugins/copy_api_dirs.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a80..01718d98dffe0 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl clean compile unidoc` puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" From 3cb3ad0baef3e43b12863326e82a4afd64cfbeb2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 17 Sep 2015 11:14:52 -0700 Subject: [PATCH 0047/1168] [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai Closes #8788 from yhuai/udafConversion. --- .../sql/catalyst/CatalystTypeConverters.scala | 7 +- .../spark/sql/RandomDataGenerator.scala | 16 ++- .../spark/sql/execution/aggregate/udaf.scala | 37 +++++- .../org/apache/spark/sql/QueryTest.scala | 21 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 11 ++ .../execution/AggregationQuerySuite.scala | 108 +++++++++++++++++- 6 files changed, 188 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 966623ed017ba..f25591794abdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 4025cbcec1019..e48395028e399 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -108,7 +108,21 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) case TimestampType => val generator = () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d43d3dd9ffaae..1114fe6552bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF( } } + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } + // This buffer is only used at executor side. private[this] var inputAggregateBuffer: InputAggregationBuffer = null @@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index cada03e9ac6bb..e3c5a426671d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -115,19 +115,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 46d87843dfa4d..7992fd59ff4ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { assert(new MyDenseVectorUDT().typeName === "mydensevector") assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index a73b1bd52c09f..24b1846923c77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,13 +17,55 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { + + def inputSchema: StructType = schema + + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. + Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. + val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnaswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnaswer) + ) + } + } } class SortBasedAggregationQuerySuite extends AggregationQuerySuite { From ea06e502184ffd40e2b4a6382bb942bb2bdfb1d0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 11:24:38 -0700 Subject: [PATCH 0048/1168] [SPARK-10394] [ML] Make GBTParams use shared stepSize ```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang Closes #8552 from yanboliang/spark-10394. --- .../org/apache/spark/ml/tree/treeParams.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d29f5253c9c3f..42e74ce6d2c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -365,17 +365,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( From d2012fde61b4108ade929243315d4ef5838e53a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Sep 2015 11:40:24 -0700 Subject: [PATCH 0049/1168] [SPARK-10657] Remove SCP-based Jenkins log archiving As of https://issues.apache.org/jira/browse/SPARK-7561, we no longer need to use our custom SCP-based mechanism for archiving Jenkins logs on the master machine; this has been superseded by the use of a Jenkins plugin which archives the logs and provides public links to view them. Per shaneknapp, we should remove this log syncing mechanism if it is no longer necessary; removing the need to SCP from the Jenkins workers to the masters is a desired step as part of some larger Jenkins infra refactoring. Author: Josh Rosen Closes #8793 from JoshRosen/remove-jenkins-ssh-to-master. --- dev/run-tests-jenkins | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3be78575e70f1..d3b05fa6df0ce 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -116,39 +116,6 @@ function post_message () { fi } -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - # post start message { start_message="\ @@ -244,8 +211,6 @@ done test_result_note=" * This patch **fails $failing_test**." fi - - send_archived_logs } # post end message From 0002e997b5b9adf41577ea3afc4e67c73ba095a3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 14:09:06 -0700 Subject: [PATCH 0050/1168] [SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions This is a pre-req for supporting the ":" operator in the RFormula feature transformer. Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7987 from ericl/interaction. --- .../apache/spark/ml/feature/Interaction.scala | 278 ++++++++++++++++++ .../spark/ml/feature/InteractionSuite.scala | 165 +++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 0000000000000..9194763fb32f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. + */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. + * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. + * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. + */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 0000000000000..2beb62ca08233 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } +} From d203143f999dc054e5ee13d7c9516bc99a4e4516 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 17 Sep 2015 19:16:34 -0700 Subject: [PATCH 0051/1168] [SPARK-10674] [TESTS] Increase timeouts in SaslIntegrationSuite. 1s seems to trigger too many times on the jenkins build boxes, so increase the timeout and cross fingers. Author: Marcelo Vanzin Closes #8802 from vanzin/SPARK-10674 and squashes the following commits: 3c93117 [Marcelo Vanzin] Use java 7 syntax. d667d1b [Marcelo Vanzin] [SPARK-10674] [tests] Increase timeouts in SaslIntegrationSuite. --- .../spark/network/sasl/SaslIntegrationSuite.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 5cb0e4d4a6458..c393a5e1e6810 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -56,6 +56,11 @@ import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private final static long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; @@ -102,7 +107,7 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg } @@ -131,7 +136,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(new byte[13], TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -139,7 +144,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -217,12 +222,12 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), 10000); + client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000); + byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); long streamId = stream.streamId; From 63f5a299eb197fee21769f4c8f57e7e7bf2a9805 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 21:37:10 -0700 Subject: [PATCH 0052/1168] [SPARK-8518] [ML] Log-linear models for survival analysis [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang Closes #8611 from yanboliang/spark-8518. --- .../ml/regression/AFTSurvivalRegression.scala | 449 ++++++++++++++++++ .../AFTSurvivalRegressionSuite.scala | 311 ++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 0000000000000..5b25db651f56c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. + * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range [0, 1]. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + + /** Checks whether the input has quantile probabilities array. */ + protected[regression] def hasQuantileProbabilities: Boolean = { + isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + } + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. + */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The weights vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialWeights = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val weights = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(weights.slice(2, weights.length)) + val intercept = weights(1) + val scale = math.exp(weights(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + require(hasQuantileProbabilities, + "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } +} + +/** + * AFTAggregator computes the gradient and loss for a AFT loss function, + * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * + * The loss function and likelihood function under the AFT model based on: + * Lawless, J. F., Statistical Models and Methods for Lifetime Data, + * New York: John Wiley & Sons, Inc. 2003. + * + * Two AFTAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n, + * with possible right-censoring, the likelihood function under the AFT model is given as + * {{{ + * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} + * }}} + * Where \delta_{i} is the indicator of the event has occurred i.e. uncensored or not. + * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function + * assumes the form + * {{{ + * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+ + * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] + * }}} + * Where S_{0}(\epsilon_{i}) is the baseline survivor function, + * and f_{0}(\epsilon_{i}) is corresponding density function. + * + * The most commonly used log-linear survival regression method is based on the Weibull + * distribution of the survival time. The Weibull distribution for lifetime corresponding + * to extreme value distribution for log of the lifetime, + * and the S_{0}(\epsilon) function is + * {{{ + * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) + * }}} + * the f_{0}(\epsilon_{i}) function is + * {{{ + * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) + * }}} + * The log-likelihood function for Weibull distribution of lifetime is + * {{{ + * \iota(\beta,\sigma)= + * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] + * }}} + * Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, + * the loss function we use to optimize is -\iota(\beta,\sigma). + * The gradient functions for \beta and \log\sigma respectively are + * {{{ + * \frac{\partial (-\iota)}{\partial \beta}= + * \sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} + * }}} + * {{{ + * \frac{\partial (-\iota)}{\partial (\log\sigma)}= + * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] + * }}} + * @param weights The log of scale parameter, the intercept and + * regression coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) + extends Serializable { + + // beta is the intercept and regression coefficients to the covariates + private val beta = weights.slice(1, weights.length) + // sigma is the scale parameter of the AFT model + private val sigma = math.exp(weights(0)) + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientLogSigmaSum = 0.0 + + def count: Long = totalCnt + + def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt + + // Here we optimize loss function over beta and log(sigma) + def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), + gradientBetaSum/totalCnt.toDouble) + + /** + * Add a new training data to this AFTAggregator, and update the loss and gradient + * of the objective function. + * + * @param data The AFTPoint representation for one data point to be added into this aggregator. + * @return This AFTAggregator object. + */ + def add(data: AFTPoint): this.type = { + + // TODO: Don't create a new xi vector each time. + val xi = if (fitIntercept) { + Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze + } else { + Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze + } + val ti = data.label + val delta = data.censor + val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + + lossSum += math.log(sigma) * delta + lossSum += (math.exp(epsilon) - delta * epsilon) + + // Sanity check (should never occur): + assert(!lossSum.isInfinity, + s"AFTAggregator loss sum is infinity. Error for unknown reason.") + + gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma + gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + + totalCnt += 1 + this + } + + /** + * Merge another AFTAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other AFTAggregator to be merged. + * @return This AFTAggregator object. + */ + def merge(other: AFTAggregator): this.type = { + if (totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + gradientBetaSum += other.gradientBetaSum + gradientLogSigmaSum += other.gradientLogSigmaSum + } + this + } +} + +/** + * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. + * It returns the loss and gradient at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ +private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + + val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + seqOp = (c, v) => (c, v) match { + case (aggregator, instance) => aggregator.add(instance) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + (aftAggregator.loss, aftAggregator.gradient) + } +} + +/** + * Class that represents the (features, label, censor) of a data point. + * + * @param features List of features for this data point. + * @param label Label for this data point. + * @param censor Indicator of the event has occurred or not. If the value is 1, it means + * the event has occurred i.e. uncensored; otherwise censored. + */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 0000000000000..ca7140a45ea65 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} From 0252b2b4161b7d46ce237c65851edacef13042c6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:05:20 -0700 Subject: [PATCH 0053/1168] [SPARK-10682] [GRAPHX] Remove Bagel test suites. Bagel has been deprecated and we haven't done any changes to it. There is no need to run those tests. This should speed up tests by 1 min. Author: Reynold Xin Closes #8807 from rxin/SPARK-10682. --- bagel/src/test/resources/log4j.properties | 27 ----- .../org/apache/spark/bagel/BagelSuite.scala | 113 ------------------ 2 files changed, 140 deletions(-) delete mode 100644 bagel/src/test/resources/log4j.properties delete mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae92096..0000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74b..0000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} From 19609bc68dd55837cbd232e6d4b1cbfab5642e5f Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 17 Sep 2015 22:25:24 -0700 Subject: [PATCH 0054/1168] [SPARK-9522] [SQL] SparkSubmit process can not exit if kill application when HiveThriftServer was starting When we start HiveThriftServer, we will start SparkContext first, then start HiveServer2, if we kill application while HiveServer2 is starting then SparkContext will stop successfully, but SparkSubmit process can not exit. Author: linweizhong Closes #7853 from Sephiroth-Lin/SPARK-9522. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c3218719f7fc..ebd8e946ee7a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -97,7 +97,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0b..a0643cec0fb7c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -93,6 +93,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) From 3ef83d95b3a3cf77f5a5bf673721b15fc4fa6b78 Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Thu, 17 Sep 2015 22:42:46 -0700 Subject: [PATCH 0055/1168] docs/running-on-mesos.md: state default values in default column This PR simply uses the default value column for defaults. Author: Felix Bechstein Closes #8810 from felixb/fix_mesos_doc. --- docs/running-on-mesos.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 247e6ecfbdb86..1814fb32ed8a5 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -332,21 +332,21 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal - Framework principal to authenticate to Mesos + (none) Set the principal with which Spark framework will use to authenticate with Mesos. spark.mesos.secret - Framework secret to authenticate to Mesos + (none)/td> Set the secret with which Spark framework will use to authenticate with Mesos. spark.mesos.role - Role for the Spark framework + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -354,7 +354,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints - Attribute based constraints to be matched against when accepting resource offers. + (none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes.
    From e4484d4011e1807d2077c584a9848befcdf79261 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:46:13 -0700 Subject: [PATCH 0056/1168] Added tag to documentation. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1814fb32ed8a5..330c159c67bca 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -346,7 +346,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.role - * + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. From b78f59c4a288c05221163a799db5c7cc97ad4611 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 18 Sep 2015 00:43:02 -0700 Subject: [PATCH 0057/1168] [SPARK-10684] [SQL] StructType.interpretedOrdering need not to be serialized Kryo fails with buffer overflow even with max value (2G). {noformat} org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 1 Serialization trace: containsChild (org.apache.spark.sql.catalyst.expressions.BoundReference) child (org.apache.spark.sql.catalyst.expressions.SortOrder) array (scala.collection.mutable.ArraySeq) ordering (org.apache.spark.sql.catalyst.expressions.InterpretedOrdering) interpretedOrdering (org.apache.spark.sql.types.StructType) schema (org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema). To avoid this, increase spark.kryoserializer.buffer.max value. at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:263) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:240) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) {noformat} Author: navis.ryu Closes #8808 from navis/SPARK-10684. --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d8968ef806390..b29cf22dcb582 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -305,7 +305,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } - private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { From 65f289a84b3835ae9ad2b2bc03c7a22d4152be33 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 18 Sep 2015 08:22:38 -0700 Subject: [PATCH 0058/1168] [SPARK-10451] [SQL] Prevent unnecessary serializations in InMemoryColumnarTableScan Many of the fields in InMemoryColumnar scan and InMemoryRelation can be made transient. This reduces my 1000ms job to abt 700 ms . The task size reduces from 2.8 mb to ~1300kb Author: Yash Datta Closes #8604 from saucam/serde. --- .../columnar/InMemoryColumnarTableScan.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 66d429bc06198..d7e145f9c2bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -48,10 +48,10 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +62,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -196,7 +196,7 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes @@ -205,7 +205,7 @@ private[sql] case class InMemoryColumnarTableScan( // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -268,16 +268,23 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitions { cachedBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) // Find the ordinals and data types of the requested columns. If none are requested, use the // narrowest (the field with minimum default element size). val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => + relOutput.zipWithIndex.map { case (a, ordinal) => ordinal -> a.dataType } minBy { case (_, dataType) => ColumnType(dataType).defaultSize @@ -285,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip } @@ -296,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => ColumnAccessor( - relation.output(batchColumnIndex).dataType, + relOutput(batchColumnIndex).dataType, ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } @@ -328,7 +335,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" From 56259770d7a861a328f4f991ae5f88cace13ab2a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Sep 2015 09:53:52 -0700 Subject: [PATCH 0059/1168] [SPARK-10615] [PYSPARK] change assertEquals to assertEqual As ```assertEquals``` is deprecated, so we need to change ```assertEquals``` to ```assertEqual``` for existing python unit tests. Author: Yanbo Liang Closes #8814 from yanboliang/spark-10615. --- python/pyspark/ml/tests.py | 16 +-- python/pyspark/mllib/tests.py | 162 +++++++++++++++--------------- python/pyspark/sql/tests.py | 18 ++-- python/pyspark/streaming/tests.py | 2 +- 4 files changed, 99 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b892318f50bd9..648fa8858fba3 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -182,7 +182,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -195,7 +195,7 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", @@ -264,23 +264,23 @@ def test_ngram(self): self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["panda"]) + self.assertEqual(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) - self.assertEquals(stopWordRemover.getInputCol(), "input") - self.assertEquals(stopWordRemover.getStopWords(), stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a"]) + self.assertEqual(transformedDF.head().output, ["a"]) class HasInducedError(Params): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 636f9a06cab7b..96cf13495aa95 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -166,13 +166,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -181,27 +181,27 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(hash(v1), hash(v2)) - self.assertEquals(hash(v1), hash(v3)) - self.assertEquals(hash(v2), hash(v3)) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) self.assertFalse(hash(v1) == hash(v4)) self.assertFalse(hash(v2) == hash(v4)) @@ -212,8 +212,8 @@ def test_eq(self): v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(v1, v2) - self.assertEquals(v1, v3) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) self.assertFalse(v2 == v4) self.assertFalse(v1 == v5) self.assertFalse(v1 == v6) @@ -238,13 +238,13 @@ def test_conversion(self): def test_sparse_vector_indexing(self): sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[-1], 2) + self.assertEqual(sv[-2], 0) + self.assertEqual(sv[-4], 0) for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: @@ -255,7 +255,7 @@ def test_matrix_indexing(self): expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -308,11 +308,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -325,13 +325,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -339,11 +339,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -352,18 +352,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -412,8 +412,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -443,8 +443,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -456,7 +456,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -711,18 +711,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, _convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -730,7 +730,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -739,8 +739,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -759,8 +759,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -984,12 +984,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -1002,7 +1002,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -1044,8 +1044,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. self.assertIsNone(stkm.latestModel()) @@ -1053,9 +1053,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -1070,7 +1070,7 @@ def test_accuracy_for_single_center(self): self.ssc.start() def condition(): - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True self._eventually(condition, catch_assertions=True) @@ -1114,7 +1114,7 @@ def test_trainOn_model(self): def condition(): finalModel = stkm.latestModel() self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True self._eventually(condition, catch_assertions=True) @@ -1141,7 +1141,7 @@ def update(rdd): self.ssc.start() def condition(): - self.assertEquals(result, [[0], [1], [2], [3]]) + self.assertEqual(result, [[0], [1], [2], [3]]) return True self._eventually(condition, catch_assertions=True) @@ -1263,7 +1263,7 @@ def test_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(models), len(input_batches)) + self.assertEqual(len(models), len(input_batches)) return True # We want all batches to finish for this test. @@ -1297,7 +1297,7 @@ def test_predictions(self): self.ssc.start() def condition(): - self.assertEquals(len(true_predicted), len(input_batches)) + self.assertEqual(len(true_predicted), len(input_batches)) return True self._eventually(condition, catch_assertions=True) @@ -1400,7 +1400,7 @@ def test_parameter_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(model_weights), len(batches)) + self.assertEqual(len(model_weights), len(batches)) return True # We want all batches to finish for this test. @@ -1433,7 +1433,7 @@ def test_prediction(self): self.ssc.start() def condition(): - self.assertEquals(len(samples), len(batches)) + self.assertEqual(len(samples), len(batches)) return True # We want all batches to finish for this test. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f2172b7a27d88..3e680f1030a71 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,7 +157,7 @@ class DataTypeTests(unittest.TestCase): def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -393,7 +393,7 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] @@ -403,7 +403,7 @@ def test_create_dataframe_from_objects(self): def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from datetime import date, datetime @@ -519,14 +519,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -554,14 +554,14 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key @@ -826,8 +826,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index cfea95b0dec71..e4e56fff3b3fc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -693,7 +693,7 @@ def check_output(n): # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) # Verify that getActiveOrCreate() uses existing SparkContext From 7f1e4413592eaa46d71d4795123c1664ef722222 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 12:19:08 -0700 Subject: [PATCH 0060/1168] [SPARK-10540] Fixes flaky all-data-type test This PR breaks the original test case into multiple ones (one test case for each data type). In this way, test failure output can be much more readable. Within each test case, we build a table with two columns, one of them is for the data type to test, the other is an "index" column, which is used to sort the DataFrame and workaround [SPARK-10591] [1] [1]: https://issues.apache.org/jira/browse/SPARK-10591 Author: Cheng Lian Closes #8768 from liancheng/spark-10540/test-all-data-types. --- .../sql/sources/hadoopFsRelationSuites.scala | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8ffcef85668d6..d7504936d90e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,80 +100,57 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - ignore("test all data types") { - withTempPath { file => - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) - // TODO: add CalendarIntervalType to here once we can save it out. - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) - - // Generate data at the driver side. We need to materialize the data first and then - // create RDD. - val maybeDataGenerator = - RandomDataGenerator.forType( - dataType = schema, + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, nullable = true, - seed = Some(System.nanoTime())) - val dataGenerator = - maybeDataGenerator - .getOrElse(fail(s"Failed to create data generator for schema $schema")) - val data = (1 to 10).map { i => - dataGenerator.apply() match { - case row: Row => row - case null => Row.fromSeq(Seq.fill(schema.length)(null)) - case other => - fail(s"Row or null is expected to be generated, " + - s"but a ${other.getClass.getCanonicalName} is generated.") + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") } - } - // Create a DF for the schema with random data. - val rdd = sqlContext.sparkContext.parallelize(data, 10) - val df = sqlContext.createDataFrame(rdd, schema) + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - // All columns that have supported data types of this source. - val supportedColumns = schema.fields.collect { - case StructField(name, dataType, _, _) if supportsDataType(dataType) => name - } - val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) - - val dfToBeSaved = df.selectExpr(selectedColumns: _*) - - // Save the data out. - dfToBeSaved - .write - .format(dataSourceName) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .save(file.getCanonicalPath) + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) - val loadedDF = - sqlContext + val loadedDF = sqlContext .read .format(dataSourceName) - .schema(dfToBeSaved.schema) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .load(file.getCanonicalPath) - .selectExpr(selectedColumns: _*) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") - // Read the data back. - checkAnswer( - loadedDF, - dfToBeSaved - ) + checkAnswer(loadedDF, df) + } } } From 282b9967303acdea2a95030008494f2de8e20dc3 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 18 Sep 2015 13:20:13 -0700 Subject: [PATCH 0061/1168] [SPARK-10539] [SQL] Project should not be pushed down through Intersect or Except #8742 Intersect and Except are both set operators and they use the all the columns to compare equality between rows. When pushing their Project parent down, the relations they based on would change, therefore not an equivalent transformation. JIRA: https://issues.apache.org/jira/browse/SPARK-10539 I added some comments based on the fix of https://github.com/apache/spark/pull/8742. Author: Yijie Shen Author: Yin Huai Closes #8823 from yhuai/fix_set_optimization. --- .../sql/catalyst/optimizer/Optimizer.scala | 37 ++++++++++--------- .../optimizer/SetOperationPushDownSuite.scala | 23 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++ 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb3..324f40a051c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. + * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72c..3fca47a023dc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580e..1370713975f2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } } From ed296deeb64656e968500c535bc4853b83c64820 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 13:47:14 -0700 Subject: [PATCH 0062/1168] [SPARK-10449] [SQL] Don't merge decimal types with incompatable precision or scales From JIRA: Schema merging should only handle struct fields. But currently we also reconcile decimal precision and scale information. Author: Holden Karau Closes #8634 from holdenk/SPARK-10449-dont-merge-different-precision. --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b29cf22dcb582..d6b436724b2a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -373,10 +373,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"scala $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt From 798a7bd1697a65214e6327c26ea5d4b3154c287b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Sep 2015 13:48:41 -0700 Subject: [PATCH 0063/1168] [SPARK-9808] Remove hash shuffle file consolidation. Author: Reynold Xin Closes #8812 from rxin/SPARK-9808-1. --- .../shuffle/FileShuffleBlockResolver.scala | 178 ++---------------- .../apache/spark/storage/BlockManager.scala | 9 - .../org/apache/spark/storage/DiskStore.scala | 3 - .../hash/HashShuffleManagerSuite.scala | 110 ----------- docs/configuration.md | 10 - .../shuffle/ExternalShuffleBlockResolver.java | 4 - project/MimaExcludes.scala | 4 + 7 files changed, 17 insertions(+), 301 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4df..d9902f96dfd4e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). @@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. @@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. */ @@ -209,17 +125,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() } logInfo("Deleted all files for shuffle " + shuffleId) true @@ -229,10 +137,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +145,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d31aa68eb6954..bca3942f8c555 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -106,15 +106,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f45956282166..feb9533604ffb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -154,9 +154,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { file.delete() } else { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e184..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/docs/configuration.md b/docs/configuration.md index 1a701f18881fe..3700051efb448 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -390,16 +390,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. - - spark.shuffle.consolidateFiles - false - - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. - - spark.shuffle.file.buffer 32k diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 79beec4429a99..c5f93bb47f55c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -50,9 +50,6 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -254,7 +251,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. */ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c96b0958586f..814a11e588ceb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,6 +70,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ + Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") ) case v if v.startsWith("1.5") => Seq( From ecded0cbc9782603398b34fb6f2bfa8dea5c2039 Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Fri, 18 Sep 2015 15:40:58 -0700 Subject: [PATCH 0064/1168] [SPARK-10611] Clone Configuration for each task for NewHadoopRDD This patch attempts to fix the Hadoop Configuration thread safety issue for NewHadoopRDD in the same way SPARK-2546 fixed the issue for HadoopRDD. Author: Mingyu Kim Closes #8763 from mingyukim/mkim/SPARK-10611. --- .../org/apache/spark/rdd/BinaryFileRDD.scala | 5 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 6fec00dcd0d85..aedced7408cde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -34,12 +34,13 @@ private[spark] class BinaryFileRDD[T]( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 174979aaeb231..2872b93b8730e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -44,7 +44,6 @@ private[spark] class NewHadoopPartition( extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -84,6 +83,27 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. + NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -268,12 +292,13 @@ private[spark] class WholeTextFileRDD( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) From c1c6e47b92c6be0dfb70630c650197f1c96d864e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Sep 2015 16:23:05 -0700 Subject: [PATCH 0065/1168] [MINOR] [ML] override toString of AttributeGroup This makes equality test failures much more readable. mengxr Author: Eric Liang Author: Eric Liang Closes #8826 from ericl/attrgroupstr. --- .../scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd38..2c29eeb01a921 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** From c20cca122fec038c9b0c0478d1ee8b112446a526 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 18:42:20 -0700 Subject: [PATCH 0066/1168] [SPARK-10623] [SQL] Fixes ORC predicate push-down When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers. This PR is based on #8783 authored by zhzhan. I also took the chance to simply `OrcFilters` a little bit to improve readability. Author: Cheng Lian Closes #8799 from liancheng/spark-10623/fix-orc-ppd. --- .../spark/sql/hive/orc/OrcFilters.scala | 56 ++++++++----------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 30 ++++++++++ 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index b3d9f7f71a27d..27193f54d3a91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgumentFactory.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { @@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.nullSafeEquals(attribute, _)) + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) + Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8bc33fcf5d906..5eb39b1129701 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + + val path = dir.getCanonicalPath + sqlContext.range(10).coalesce(1).write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { + checkAnswer(df.where(pred), answer.map(Row(_))) + } + + checkPredicate('id === 5, Seq(5L)) + checkPredicate('id <=> 5, Seq(5L)) + checkPredicate('id < 5, 0L to 4L) + checkPredicate('id <= 5, 0L to 5L) + checkPredicate('id > 5, 6L to 9L) + checkPredicate('id >= 5, 5L to 9L) + checkPredicate('id.isNull, Seq.empty[Long]) + checkPredicate('id.isNotNull, 0L to 9L) + checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) + checkPredicate('id > 0 && 'id < 3, 1L to 2L) + checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) + checkPredicate(!('id > 3), 0L to 3L) + checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + } + } + } } From 39667c3ecf65b35f069a4194c6684e7c503559e3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Sep 2015 23:58:25 -0700 Subject: [PATCH 0067/1168] [SPARK-10474] [SQL] Aggregation fails to allocate memory for pointer array When `TungstenAggregation` hits memory pressure, it switches from hash-based to sort-based aggregation in-place. However, in the process we try to allocate the pointer array for writing to the new `UnsafeExternalSorter` *before* actually freeing the memory from the hash map. This lead to the following exception: ``` java.io.IOException: Could not acquire 65536 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.initializeForWriting(UnsafeExternalSorter.java:169) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:220) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:126) at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:257) at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.switchToSortBasedAggregation(TungstenAggregationIterator.scala:435) ``` Author: Andrew Or Closes #8827 from andrewor14/allocate-pointer-array. --- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++- .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 49 ++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b1..14b6aafdea7df 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -159,7 +159,7 @@ public BoxedUnit apply() { /** * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void initializeForWriting() throws IOException { + public void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); final long pointerArrayMemory = UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); @@ -187,6 +187,14 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + spill(true); + } + + /** + * Sort and spill the current records in response to memory pressure. + * @param shouldInitializeForWriting whether to allocate memory for writing after the spill + */ + public void spill(boolean shouldInitializeForWriting) throws IOException { assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -217,7 +225,9 @@ public void spill() throws IOException { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + if (shouldInitializeForWriting) { + initializeForWriting(); + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7db6b7ff50f22..b81f67a16b815 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,6 +85,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // We will use the number of elements in the map as the initialSize of the // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. + // TODO: track pointer array memory used by this in-memory sorter! final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); @@ -123,8 +124,13 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); + // Note: This spill doesn't actually release any memory, so if we try to allocate a new + // pointer array immediately after the spill then we may fail to acquire sufficient space + // for it (SPARK-10474). For this reason, we must initialize for writing explicitly *after* + // we have actually freed memory from our map. + sorter.spill(false /* initialize for writing */); map.free(); + sorter.initializeForWriting(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52f..ada4d42f991ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -325,7 +326,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +334,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } From 0f41dba59ea6e2ddfca20a1ce07bf0f44cf9c04f Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Sep 2015 01:59:36 -0700 Subject: [PATCH 0068/1168] [SPARK-10584] [SQL] [DOC] Documentation about the compatible Hive version is wrong. In Spark 1.5.0, Spark SQL is compatible with Hive 0.12.0 through 1.2.1 but the documentation is wrong. /CC yhuai Author: Kousuke Saruta Closes #8776 from sarutak/SPARK-10584-2. --- docs/sql-programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0b911d207243..82d4243cc6b27 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1954,7 +1954,7 @@ without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1 You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -2260,8 +2260,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). #### Deploying in Existing Hive Warehouses From 178bfb8f1c066b10acc4d19c778f110b2edd6390 Mon Sep 17 00:00:00 2001 From: Alexis Seigneurin Date: Sat, 19 Sep 2015 12:01:22 +0100 Subject: [PATCH 0069/1168] Fixed links to the API Submitting this change on the master branch as requested in https://github.com/apache/spark/pull/8819#issuecomment-141505941 Author: Alexis Seigneurin Closes #8838 from aseigneurin/patch-2. --- docs/ml-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5d7f990021f1..0427ac6695aa1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -619,13 +619,13 @@ for row in selected.collect(): An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. From 49fd5210a513152068b5338df421881b0178cb1e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 18:22:43 -0700 Subject: [PATCH 0070/1168] [SPARK-10155] [SQL] Change SqlParser to object to avoid memory leak Since `scala.util.parsing.combinator.Parsers` is thread-safe since Scala 2.10 (See [SI-4929](https://issues.scala-lang.org/browse/SI-4929)), we can change SqlParser to object to avoid memory leak. I didn't change other subclasses of `scala.util.parsing.combinator.Parsers` because there is only one instance in one SQLContext, which should not be an issue. Author: zsxwing Closes #8357 from zsxwing/sql-memory-leak. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 5898a5f93f381..2bac08eac4fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb1..e21d3c05464b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c9573..dfab2398857e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -37,9 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +48,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3e61123c145cd..8f737c2023931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -720,7 +720,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -745,7 +745,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +769,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 745bb4ec9cf1c..03e973666e888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -163,7 +163,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -197,7 +197,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e3fdd782e6ff6..f099940800cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -590,7 +590,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -636,7 +636,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -732,7 +732,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(new SqlParser().parseTableIdentifier(tableName)) + table(SqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 60d9c509104d5..2467b4e48415b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -823,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d37ba5ddc2d80..c12a734863326 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -291,12 +291,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -311,7 +311,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a5569b0a4446..0c1b41e3377e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive options: Map[String, String], isExternal: Boolean): Unit = { createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), + SqlParser.parseTableIdentifier(tableName), userSpecifiedSchema, partitionColumns, provider, @@ -375,7 +375,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { From b33fbcf2bd3f7e3e3270a759238635a32df077d4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 19 Sep 2015 21:40:21 -0700 Subject: [PATCH 0071/1168] [SPARK-10710] Remove ability to disable spilling in core and SQL It does not make much sense to set `spark.shuffle.spill` or `spark.sql.planner.externalSort` to false: I believe that these configurations were initially added as "escape hatches" to guard against bugs in the external operators, but these operators are now mature and well-tested. In addition, these configurations are not handled in a consistent way anymore: SQL's Tungsten codepath ignores these configurations and will continue to use spilling operators. Similarly, Spark Core's `tungsten-sort` shuffle manager does not respect `spark.shuffle.spill=false`. This pull request removes these configurations, adds warnings at the appropriate places, and deletes a large amount of code which was only used in code paths that did not support spilling. Author: Josh Rosen Closes #8831 from JoshRosen/remove-ability-to-disable-spilling. --- .../scala/org/apache/spark/Aggregator.scala | 59 +++++-------------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 40 ++++--------- .../shuffle/hash/HashShuffleManager.scala | 8 ++- .../shuffle/sort/SortShuffleManager.scala | 10 +++- .../util/collection/ExternalSorter.scala | 6 -- .../spark/deploy/SparkSubmitSuite.scala | 22 +++---- docs/configuration.md | 14 +---- docs/sql-programming-guide.md | 7 --- python/pyspark/rdd.py | 25 +++----- python/pyspark/shuffle.py | 30 ---------- python/pyspark/tests.py | 13 +--- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +-- .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/commands.scala | 9 +++ .../org/apache/spark/sql/execution/sort.scala | 30 +--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++------ .../execution/RowFormatConvertersSuite.scala | 2 +- .../spark/sql/execution/SortSuite.scala | 4 +- 18 files changed, 81 insertions(+), 234 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e51..7196e57d5d2e2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. - private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7bad749d58327..935c3babd8ea1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -128,8 +128,6 @@ class CoGroupedRDD[K: ClassTag]( override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -150,34 +148,16 @@ class CoGroupedRDD[K: ClassTag]( rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409dd..0b46634b8b466 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b8..476cc1f303da7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,11 +19,17 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 31230d5978b2a..2a30f751ff03d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -116,8 +116,6 @@ private[spark] class ExternalSorter[K, V, C]( private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -229,10 +227,6 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - var estimatedSize = 0L if (usingMap) { estimatedSize = map.estimateSize() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1110ca6051a40..1fd470cd3b01d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -147,7 +147,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -166,7 +166,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -185,7 +185,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -206,7 +206,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") { @@ -229,7 +229,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -253,9 +253,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -266,7 +266,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -277,7 +277,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -288,7 +288,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -299,7 +299,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { diff --git a/docs/configuration.md b/docs/configuration.md index 3700051efb448..5ec097c78aa38 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -449,8 +449,8 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.memoryFraction 0.2 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will begin to spill to disk. If spills are often, consider increasing this value at the expense of spark.storage.memoryFraction. @@ -483,14 +483,6 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. - - spark.shuffle.spill - true - - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. - - spark.shuffle.spill.compress true diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 82d4243cc6b27..7ae9244c271e3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1936,13 +1936,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - spark.sql.planner.externalSort - true - - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. - - # Distributed SQL Engine diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ab5aab1e115f7..73d7d9a5692a9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -1770,13 +1768,11 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1784,8 +1780,7 @@ def combineLocally(iterator): shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -1824,9 +1819,6 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1857,14 +1849,12 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1872,8 +1862,7 @@ def combine(iterator): shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca76..e974cda9fc3e1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. - """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 647504c32f156..f11aaf001c8df 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,7 +62,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +95,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d084..b9fb90d964206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -330,11 +330,6 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") @@ -422,6 +417,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" } } @@ -476,8 +472,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5e40d77689045..41b215c79296a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 95209e6634519..af28e2dfa4186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -105,6 +105,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 40ef7c3b53530..27f26245a5ef0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -31,38 +31,12 @@ import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} // This file defines various sort operators. //////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - /** * Performs a sort, spilling to disk as needed. * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ -case class ExternalSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan) @@ -93,7 +67,7 @@ case class ExternalSort( } /** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Optimized version of [[Sort]] that operates on binary data (implemented as part of * Project Tungsten). * * @param global when true performs a global sort of all partitions by shuffling the data first diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f9981356f364f..05b4127cbcaff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -581,28 +581,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -1731,10 +1715,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 4492e37ad01ff..5dc37e5c3c238 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,7 +32,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 3073d492e613b..847c188a30333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -36,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } From 9c690bb2bd611e07e82c721e746f701ae82252dc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 20 Sep 2015 16:05:12 -0700 Subject: [PATCH 0072/1168] [SPARK-5905] [MLLIB] Note requirements for certain RowMatrix methods in docs Note methods that fail for cols > 65535; note that SVD does not require n >= m CC mengxr Author: Sean Owen Closes #8839 from srowen/SPARK-5905. --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e55ef26858adb..7c7d900af3d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -109,7 +109,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -150,7 +151,8 @@ class RowMatrix @Since("1.0.0") ( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -320,7 +322,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ @Since("1.0.0") @@ -374,6 +377,8 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components */ From c63e39f490438e8c78eabf590eed526d423bd92f Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 20 Sep 2015 16:16:31 -0700 Subject: [PATCH 0073/1168] [SPARK-10715] [ML] Duplicate initialization flag in WeightedLeastSquare There are duplicate set of initialization flag in `WeightedLeastSquares#add`. `initialized` is already set in `init(Int)`. Author: lewuathe Closes #8837 from Lewuathe/duplicate-initialization-flag. --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 0ff8931b0bab4..4374e99631560 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -193,7 +193,6 @@ private[ml] object WeightedLeastSquares { val ak = a.size if (!initialized) { init(ak) - initialized = true } assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") count += 1L From 88db1f9158a4f6114acc4f74a491266c1cfbbf67 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 20 Sep 2015 22:55:24 -0700 Subject: [PATCH 0074/1168] [SPARK-10631] [DOCUMENTATION, MLLIB, PYSPARK] Added documentation for few APIs There are some missing API docs in pyspark.mllib.linalg.Vector (including DenseVector and SparseVector). We should add them based on their Scala counterparts. Author: vinodkc Closes #8834 from vinodkc/fix_SPARK-10631. --- python/pyspark/mllib/linalg/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 4829acb16ed8a..f929e3e96fbe2 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -301,11 +301,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -397,10 +400,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns an numpy.ndarray + """ return self.array @property def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -479,8 +488,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -521,11 +530,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. >>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -797,7 +809,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) From dabe52c0ff0480fc05cf16b6b9e1003eaeb080e5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 18:53:28 +0100 Subject: [PATCH 0075/1168] [SPARK-10626] [MLLIB] create java friendly method for random rdd SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method. Author: Holden Karau Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD. --- .../spark/mllib/random/RandomRDDs.scala | 52 ++++++++++++++++++- .../mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 4dd5ea214d678..f8ff26b5795be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD @@ -381,7 +382,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi @Since("1.1.0") @@ -394,6 +395,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 33d81b1e9592b..fce5f6712f462 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.random; +import java.io.Serializable; import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -231,4 +232,33 @@ public void testGammaVectorRDD() { } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } From 287b68027f8b29631ec0c42986465ab3c28dc4e2 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 21 Sep 2015 11:39:04 -0700 Subject: [PATCH 0076/1168] [SPARK-7989] [SPARK-10651] [CORE] [TESTS] Increase timeout to fix flaky tests I noticed only one block manager registered with master in an unsuccessful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/3534/) ``` 15/09/16 13:02:30.981 pool-1-thread-1-ScalaTest-running-BroadcastSuite INFO SparkContext: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 13:02:38.133 sparkDriver-akka.actor.default-dispatcher-19 INFO BlockManagerMasterEndpoint: Registering block manager localhost:48196 with 530.3 MB RAM, BlockManagerId(0, localhost, 48196) ``` In addition, the first block manager needed 7+ seconds to start. But the test expected 2 block managers so it failed. However, there was no exception in this log file. So I checked a successful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3536/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/) and it needed 4-5 seconds to set up the local cluster: ``` 15/09/16 18:11:27.738 sparkWorker1-akka.actor.default-dispatcher-5 INFO Worker: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 18:11:30.838 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:54202 with 530.3 MB RAM, BlockManagerId(1, localhost, 54202) 15/09/16 18:11:32.112 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:32955 with 530.3 MB RAM, BlockManagerId(0, localhost, 32955) ``` In this build, the first block manager needed only 3+ seconds to start. Comparing these two builds, I guess it's possible that the local cluster in `BroadcastSuite` cannot be ready in 10 seconds if the Jenkins worker is busy. So I just increased the timeout to 60 seconds to see if this can fix the issue. Author: zsxwing Closes #8813 from zsxwing/fix-BroadcastSuite. --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2 +- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 2 +- .../apache/spark/scheduler/SparkListenerWithClusterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index e846a72c888c6..231f4631e0a47 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index fb7a8ae3f9d41..ba21075ce6be5 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -311,7 +311,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff1..9fa8859382911 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) From 95b3fd9d7ead3b6125a4b6b3f489bad0e27a6cbf Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 21 Sep 2015 19:46:39 +0100 Subject: [PATCH 0077/1168] [SPARK-10662] [DOCS] Code snippets are not properly formatted in tables * Backticks are processed properly in Spark Properties table * Removed unnecessary spaces * See http://people.apache.org/~pwendell/spark-nightly/spark-master-docs/latest/running-on-yarn.html Author: Jacek Laskowski Closes #8795 from jaceklaskowski/docs-yarn-formatting. --- docs/configuration.md | 97 +++++++++++++++-------------- docs/programming-guide.md | 100 +++++++++++++++--------------- docs/running-on-mesos.md | 14 ++--- docs/running-on-yarn.md | 106 ++++++++++++++++---------------- docs/sql-programming-guide.md | 16 ++--- docs/submitting-applications.md | 8 +-- 6 files changed, 171 insertions(+), 170 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5ec097c78aa38..b22587c70316b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,20 +34,20 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may actually require one to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - + + Properties that specify a byte size should be configured with a unit of size. The following format is accepted: @@ -140,7 +140,7 @@ of the most common options to set are: Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -207,7 +207,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +216,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +228,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +242,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +250,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +259,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. @@ -460,11 +460,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.enabled false - Enables the external shuffle service. This service preserves the shuffle files written by - executors so the executors can be safely removed. This must be enabled if + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if spark.dynamicAllocation.enabled is "true". The external shuffle service must be set up in order to enable it. See - dynamic allocation + dynamic allocation configuration and setup documentation for more information. @@ -747,9 +747,9 @@ Apart from these, the following properties are also available, and may be useful 1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker. @@ -893,14 +893,14 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -909,9 +909,9 @@ Apart from these, the following properties are also available, and may be useful 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to. @@ -978,7 +978,7 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -991,8 +991,8 @@ Apart from these, the following properties are also available, and may be useful Maximum number of retries when binding to a port before giving up. When a port is given a specific value (non 0), each subsequent retry will - increment the port used in the previous attempt by 1 before retrying. This - essentially allows it to try a range of ports from the start port specified + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified to port + maxRetries. @@ -1191,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1424,11 +1424,11 @@ Apart from these, the following properties are also available, and may be useful false Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). - This enables the Spark Streaming to control the receiving rate based on the + This enables the Spark Streaming to control the receiving rate based on the current batch scheduling delays and processing times so that the system receives - only as fast as the system can process. Internally, this dynamically sets the + only as fast as the system can process. Internally, this dynamically sets the maximum receiving rate of receivers. This rate is upper bounded by the values - `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition if they are set (see below). @@ -1542,15 +1542,15 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is `python`). + Python binary executable to use for PySpark in both driver and workers (default is python). PYSPARK_DRIVER_PYTHON - Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). SPARK_LOCAL_IP @@ -1580,4 +1580,3 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. - diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4cf83bb392636..8ad238315f12c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,8 +182,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -217,7 +217,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -249,8 +249,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook.
@@ -418,9 +418,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted: @@ -435,9 +435,9 @@ Writables are automatically converted:
MapWritabledict
-Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -454,7 +454,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -474,15 +474,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
@@ -758,7 +758,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
@@ -777,7 +777,7 @@ println("Counter value: " + counter)
{% highlight java %} int counter = 0; -JavaRDD rdd = sc.parallelize(data); +JavaRDD rdd = sc.parallelize(data); // Wrong: Don't do this!! rdd.foreach(x -> counter += x); @@ -803,7 +803,7 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. @@ -813,9 +813,9 @@ To ensure well-defined behavior in these sorts of scenarios one should use an [` In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
@@ -859,7 +859,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -974,7 +974,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1025,7 +1025,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1038,7 +1038,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1094,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1118,13 +1118,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1141,26 +1141,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1246,7 +1246,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1345,7 +1345,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1474,8 +1474,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
-For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator -will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: @@ -1486,7 +1486,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being {% highlight scala %} val accum = sc.accumulator(0) data.map { x => accum += x; f(x) } -// Here, accum is still 0 because no actions have caused the `map` to be computed. +// Here, accum is still 0 because no actions have caused the map to be computed. {% endhighlight %}
@@ -1553,7 +1553,7 @@ Several changes were made to the Java API: code that `extends Function` should `implement Function` instead. * New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning +* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning `(Key, List)` pairs to `(Key, Iterable)`.
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 330c159c67bca..460a66f37dd64 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -245,7 +245,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - If set to "true", runs over Mesos clusters in + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use @@ -254,16 +254,16 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores - 0 + 0 Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. spark.mesos.mesosExecutor.cores - 1.0 + 1.0 (Fine-grained mode only) Number of cores to give each Mesos executor. This does not include the cores used to run the Spark tasks. In other words, even if no Spark task @@ -287,7 +287,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the list of volumes which will be mounted into the Docker image, which was set using spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings following the form passed to docker run -v. That is they take the form: + mappings following the form passed to docker run -v. That is they take the form:
[host_path:]container_path[:ro|:rw]
@@ -318,7 +318,7 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. @@ -339,7 +339,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.secret - (none)/td> + (none) Set the secret with which Spark framework will use to authenticate with Mesos. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3a961d245f3de..0e25ccf512c02 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -23,7 +23,7 @@ Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.ht To launch a Spark application in `yarn-cluster` mode: $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ @@ -43,7 +43,7 @@ To launch a Spark application in `yarn-client` mode, do the same, but replace `y ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ --master yarn-cluster \ @@ -64,16 +64,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster @@ -92,7 +92,7 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +100,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

+ Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +127,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In yarn-cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In yarn-client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +179,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +199,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. @@ -245,18 +247,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +266,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In yarn-cluster mode this controls + the environment of the Spark driver and in yarn-client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,19 +285,19 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. @@ -303,10 +305,10 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. @@ -332,7 +334,7 @@ If you need a reference to the proper location to put log files in the YARN so t (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, for renewing the login tickets and the delegation tokens periodically. @@ -371,14 +373,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.security.tokens.${service}.enabled - true + true Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. By default, delegation tokens for all supported services are retrieved when those services are configured, but it's possible to disable that behavior if it somehow conflicts with the application being run.

- Currently supported services are: hive, hbase + Currently supported services are: hive, hbase @@ -387,5 +389,5 @@ If you need a reference to the proper location to put log files in the YARN so t - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9244c271e3..a1cbc7de97c65 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). @@ -1706,8 +1706,8 @@ The following options can be used to configure the version of Hive that is used either 1.2.1 or not defined.

  • maven
  • Use Hive jars of specified version downloaded from Maven repositories. This configuration - is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + is not generally recommended for production deployments. +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
  • @@ -1806,7 +1806,7 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -2023,11 +2023,11 @@ options. - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with code generation for expression evaluation. These features can both be disabled by setting - `spark.sql.tungsten.enabled` to `false. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 7ea4d6f1a3f8f..915be0f479157 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -103,7 +103,7 @@ run it with `--help`. Here are a few examples of common options: export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ # can also be `yarn-client` for client mode + --master yarn-cluster \ # can also be yarn-client for client mode --executor-memory 20G \ --num-executors 50 \ /path/to/examples.jar \ @@ -174,9 +174,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries From b4de81be9baef44b8a4f1555b15be573feb473e7 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Mon, 21 Sep 2015 12:09:00 -0700 Subject: [PATCH 0078/1168] [SPARK-9642] [ML] LinearRegression should supported weighted data In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. Author: Meihua Wu Closes #8631 from rotationsymmetry/SPARK-9642. --- .../ml/regression/LinearRegression.scala | 164 +++++++++++------- .../ml/regression/LinearRegressionSuite.scala | 88 ++++++++++ project/MimaExcludes.scala | 8 +- 3 files changed, 191 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e4602d36ccc87..78a67c5fdab20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions.{col, udf, lit} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization + with HasFitIntercept with HasStandardization with HasWeightCol + +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[regression] case class Instance(label: Double, weight: Double, features: Vector) /** * :: Experimental :: @@ -123,30 +131,43 @@ class LinearRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val numFeatures = featuresSummarizer.mean.size + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -167,8 +188,8 @@ class LinearRegression(override val uid: String) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training weights except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. * @param fitIntercept Whether to fit an intercept term. @@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] ( * @param featuresMean The mean values of the features. */ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -493,26 +515,28 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -520,30 +544,33 @@ private class LeastSquaresAggregator( * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The data point instance to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") + def add(instance: Instance): this.type = + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + if (weight == 0.0) return this - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this - } + totalCnt += 1 + weightSum += weight + this + } /** * Merge another LeastSquaresAggregator, and update the loss and gradient @@ -557,8 +584,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -574,11 +602,17 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -589,7 +623,7 @@ private class LeastSquaresAggregator( * It's used in Breeze's convex optimization routines. */ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -598,17 +632,13 @@ private class LeastSquaresCostFun( featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeff = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd, labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + seqOp = (aggregator, instance) => aggregator.add(instance), + combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) val totalGradientArray = leastSquaresAggregator.gradient.toArray @@ -616,7 +646,7 @@ private class LeastSquaresCostFun( 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeff.foreachActive { (index, value) => // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. sum += { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc734..8428f4f00b370 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 814a11e588ceb..b2e6be706637b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,10 +70,14 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ - Seq( + ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this") ) case v if v.startsWith("1.5") => Seq( From ab0b015272312b7006d7bf46540be9d5518c86cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?hushan=5B=E8=83=A1=E7=8F=8A=5D?= Date: Mon, 21 Sep 2015 14:26:15 -0500 Subject: [PATCH 0079/1168] [SPARK-5259] [CORE] don't submit stage until its dependencies map outputs are registered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track pending tasks by partition ID instead of Task objects. Before this change, failure & retry could result in a case where a stage got submitted before the map output from its dependencies get registered. This was due to an error in the condition for registering map outputs. Author: hushan[胡珊] Author: Imran Rashid Closes #7699 from squito/SPARK-5259. --- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 197 ++++++++++++++++-- 4 files changed, 191 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c9a66e504403..394228b2728d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -944,7 +944,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { @@ -1060,8 +1060,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1152,7 +1152,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1198,7 +1198,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1242,7 +1242,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b37eccbd0f7b8..a3829c319c48d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -66,7 +66,7 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 62af9031b9f8b..c02597c4365c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,8 +487,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1c55f90ad9b44..6b5bcf0574de6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -479,8 +479,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -490,7 +490,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -782,8 +782,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -1035,6 +1035,173 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. * @@ -1393,8 +1560,8 @@ class DAGSchedulerSuite // Submit a map stage by itself submitMapStage(shuffleDep) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) assert(results.size === 1) results.clear() assertDataStructuresEmpty() @@ -1407,7 +1574,7 @@ class DAGSchedulerSuite // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) results.clear() @@ -1452,8 +1619,8 @@ class DAGSchedulerSuite // Complete the first stage assert(taskSets(0).stageId === 0) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", rdd1.partitions.size)), - (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -1461,7 +1628,7 @@ class DAGSchedulerSuite // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (Success, makeMapStatus("hostA", rdd2.partitions.length)), (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -1469,7 +1636,7 @@ class DAGSchedulerSuite // Stage 0 should now be running as task set 2; make its task succeed assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( - (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -1477,8 +1644,8 @@ class DAGSchedulerSuite // Stage 1 should now be running as task set 3; make its first task succeed assert(taskSets(3).stageId === 1) complete(taskSets(3), Seq( - (Success, makeMapStatus("hostB", rdd2.partitions.size)), - (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -1494,7 +1661,7 @@ class DAGSchedulerSuite // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 assert(taskSets(5).stageId === 1) complete(taskSets(5), Seq( - (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) complete(taskSets(6), Seq( (Success, 53))) assert(listener3.results === Map(0 -> 52, 1 -> 53)) From 9c1c533e0453cc6192ac2bd0782cc3980b1a7ea4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:06:23 -0700 Subject: [PATCH 0080/1168] [SPARK-9769] [ML] [PY] add python api for countvectorizermodel From JIRA: Add Python API, user guide and example for ml.feature.CountVectorizerModel Author: Holden Karau Closes #8561 from holdenk/SPARK-9769-add-python-api-for-countvectorizermodel. --- python/pyspark/ml/feature.py | 148 +++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92db8df80280b..f41d72f877256 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,12 +26,13 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', - 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -171,6 +172,141 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + """ + + @property + def vocabulary(self): + """ + An array of terms in the vocabulary. + """ + return self._call_java("vocabulary") + + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ From 02590454f43c9ea43cd4e3d28876f1de5a65bb0c Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 21 Sep 2015 13:11:28 -0700 Subject: [PATCH 0081/1168] [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing Implementation of significance testing using Streaming API. Author: Feynman Liang Author: Feynman Liang Closes #4716 from feynmanliang/ab_testing. --- .../examples/mllib/StreamingTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/StreamingTest.scala | 145 +++++++++++ .../mllib/stat/test/StreamingTestMethod.scala | 167 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/StreamingTestSuite.scala | 243 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 0000000000000..ab29f90254d34 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 0000000000000..75c6a51d09571 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. + * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. + * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. */ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 0000000000000..a7eaed51b4d55 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. + */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be944..b0916d3e84651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 0000000000000..d3e9ef4ff079c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From cc5b13cac7b466d69d42806d9401cd06c4812958 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Sep 2015 13:15:44 -0700 Subject: [PATCH 0082/1168] [SPARK-10676] [DOCS] Add documentation for SASL encryption options. Author: Marcelo Vanzin Closes #8803 from vanzin/SPARK-10676. --- docs/configuration.md | 16 ++++++++++++++++ docs/security.md | 22 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b22587c70316b..284f97ad09ec3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1285,6 +1285,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a33..177109415180b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. @@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight From f1640f5fd8bac073672401e370c6b7728955249c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:33:10 -0700 Subject: [PATCH 0083/1168] [SPARK-10630] [SQL] Add a createDataFrame API that takes in a java list It would be nice to support creating a DataFrame directly from a Java List of Row. Author: Holden Karau Closes #8779 from holdenk/SPARK-10630-create-DataFrame-from-Java-List. --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f099940800cc0..1bd4e26fb3162 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -476,6 +476,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 5f9abd4999ce0..250ac2e1092d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -181,6 +182,15 @@ public void testCreateDataFrameFromJavaBeans() { } } + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 47492374cd73a7e3d3b52fb31aa24342dd0cf7dd Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 21 Sep 2015 14:24:19 -0700 Subject: [PATCH 0084/1168] [DOC] [PYSPARK] [MLLIB] Added newlines to docstrings to fix parameter formatting Added newlines before `:param ...:` and `:return:` markup. Without these, parameter lists aren't formatted correctly in the API docs. I.e: ![screen shot 2015-09-21 at 21 49 26](https://cloud.githubusercontent.com/assets/11915197/10004686/de3c41d4-60aa-11e5-9c50-a46dcb51243f.png) .. looks like this once newline is added: ![screen shot 2015-09-21 at 21 50 14](https://cloud.githubusercontent.com/assets/11915197/10004706/f86bfb08-60aa-11e5-8524-ae4436713502.png) Author: noelsmith Closes #8851 from noel-smith/docstring-missing-newline-fix. --- python/pyspark/ml/param/__init__.py | 4 ++++ python/pyspark/ml/pipeline.py | 1 + python/pyspark/ml/tuning.py | 2 ++ python/pyspark/ml/wrapper.py | 2 ++ python/pyspark/mllib/evaluation.py | 2 +- python/pyspark/mllib/linalg/__init__.py | 1 + python/pyspark/streaming/context.py | 2 ++ python/pyspark/streaming/mqtt.py | 1 + 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index eeeac49b21980..2e0c63cb47b17 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -164,6 +164,7 @@ def extractParamMap(self, extra=None): a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ @@ -182,6 +183,7 @@ def copy(self, extra=None): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -201,6 +203,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -243,6 +246,7 @@ def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. + :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 13cf2b0f7bbd9..312a8502b3a2c 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -154,6 +154,7 @@ def __init__(self, stages=None): def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ab5621f45c72c..705ee53685752 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -254,6 +254,7 @@ def copy(self, extra=None): Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -290,6 +291,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8218c7c5f801c..4bcb4aaec89de 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -173,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec2..a90e5c50e54b9 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -147,7 +147,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index f929e3e96fbe2..ea42127f1651f 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -240,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4069d7a149986..a8c9ffc235b9e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -240,6 +240,7 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. + @param timeout: time to wait in seconds """ if timeout is None: @@ -252,6 +253,7 @@ def awaitTerminationOrTimeout(self, timeout): Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index f06598971c548..fa83006c36db6 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -31,6 +31,7 @@ def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object :param brokerUrl: Url of remote mqtt publisher :param topic: topic name to subscribe to From f6df004c0fcfa0bd8ea3965f1d34378d57b147a6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Sep 2015 16:47:52 -0700 Subject: [PATCH 0085/1168] [SPARK-10649] [STREAMING] Prevent inheriting job group and irrelevant job description in streaming jobs The job group, and job descriptions information is passed through thread local properties, and get inherited by child threads. In case of spark streaming, the streaming jobs inherit these properties from the thread that called streamingContext.start(). This may not make sense. 1. Job group: This is mainly used for cancelling a group of jobs together. It does not make sense to cancel streaming jobs like this, as the effect will be unpredictable. And its not a valid usecase any way, to cancel a streaming context, call streamingContext.stop() 2. Job description: This is used to pass on nice text descriptions for jobs to show up in the UI. The job description of the thread that calls streamingContext.start() is not useful for all the streaming jobs, as it does not make sense for all of the streaming jobs to have the same description, and the description may or may not be related to streaming. The solution in this PR is meant for the Spark master branch, where local properties are inherited by cloning the properties. The job group and job description in the thread that starts the streaming scheduler are explicitly removed, so that all the subsequent child threads does not inherit them. Also, the starting is done in a new child thread, so that setting the job group and description for streaming, does not change those properties in the thread that called streamingContext.start(). Author: Tathagata Das Closes #8781 from tdas/SPARK-10649. --- .../org/apache/spark/util/ThreadUtils.scala | 59 +++++++++++++++++++ .../apache/spark/util/ThreadUtilsSuite.scala | 24 +++++++- .../spark/streaming/StreamingContext.scala | 15 ++++- .../streaming/StreamingContextSuite.scala | 32 ++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3d..22e291a2b48d6 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -21,6 +21,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -86,4 +87,62 @@ private[spark] object ThreadUtils { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() Executors.newSingleThreadScheduledExecutor(threadFactory) } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7fc..620e4debf4e08 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random import org.apache.spark.SparkFunSuite @@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0b..6720ba4f72cf3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -588,12 +588,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +626,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d26894e88fc26..3b9d0d15ea04c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound === null) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) From 006bf95c6bd87625d23e9bc43763187938eb9b0d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 21 Sep 2015 18:06:45 -0700 Subject: [PATCH 0086/1168] [SPARK-10495] [SQL] Read date values in JSON data stored by Spark 1.5.0. https://issues.apache.org/jira/browse/SPARK-10681 Author: Yin Huai Closes #8806 from yhuai/SPARK-10495. --- .../datasources/json/JacksonGenerator.scala | 36 ++++++ .../datasources/json/JacksonParser.scala | 15 ++- .../datasources/json/JsonSuite.scala | 103 +++++++++++++++++- 3 files changed, 152 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index f65c7bbd6e29d..23bada1ddd92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -73,6 +73,38 @@ private[sql] object JacksonGenerator { valWriter(field.dataType, v) } gen.writeEndObject() + + // For UDT, udt.serialize will produce SQL types. So, we need the following three cases. + case (ArrayType(ty, _), v: ArrayData) => + gen.writeStartArray() + v.foreach(ty, (_, value) => valWriter(ty, value)) + gen.writeEndArray() + + case (MapType(kt, vt, _), v: MapData) => + gen.writeStartObject() + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) + gen.writeEndObject() + + case (StructType(ty), v: InternalRow) => + gen.writeStartObject() + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { + gen.writeFieldName(field.name) + valWriter(field.dataType, value) + } + i += 1 + } + gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) @@ -133,6 +165,10 @@ private[sql] object JacksonGenerator { i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index ff4d8c04e8eaf..c51140749c8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -62,10 +62,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6a18cc6d27138..b614e6c4148fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -1159,4 +1159,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. + // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. + val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } From 861b057f6c1c7b58f5c15557b1a37b24d171fd41 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 21 Sep 2015 21:09:59 -0700 Subject: [PATCH 0087/1168] [SPARK-10711] [SPARKR] Do not assume spark.submit.deployMode is always set In ```RUtils.sparkRPackagePath()``` we 1. Call ``` sys.props("spark.submit.deployMode")``` which returns null if ```spark.submit.deployMode``` is not suet 2. Call ``` sparkConf.get("spark.submit.deployMode")``` which throws ```NoSuchElementException``` if ```spark.submit.deployMode``` is not set. This patch simply passes a default value ("cluster") for ```spark.submit.deployMode```. cc rxin Author: Hossein Closes #8832 from falaki/SPARK-10711. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18c..fd5646b5b6372 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" From 44710976769bde76a1a0aa5ccb2955947c4632e0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 23:21:24 -0700 Subject: [PATCH 0088/1168] [SPARK-9821] [PYSPARK] pyspark-reduceByKey-should-take-a-custom-partitioner from the issue: In Scala, I can supply a custom partitioner to reduceByKey (and other aggregation/repartitioning methods like aggregateByKey and combinedByKey), but as far as I can tell from the Pyspark API, there's no way to do the same in Python. Here's an example of my code in Scala: weblogs.map(s => (getFileType(s), 1)).reduceByKey(new FileTypePartitioner(),_+_) But I can't figure out how to do the same in Python. The closest I can get is to call repartition before reduceByKey like so: weblogs.map(lambda s: (getFileType(s), 1)).partitionBy(3,hash_filetype).reduceByKey(lambda v1,v2: v1+v2).collect() But that defeats the purpose, because I'm shuffling twice instead of once, so my performance is worse instead of better. Author: Holden Karau Closes #8569 from holdenk/SPARK-9821-pyspark-reduceByKey-should-take-a-custom-partitioner. --- python/pyspark/rdd.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 73d7d9a5692a9..56e892243c79c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -686,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. @@ -695,7 +695,7 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): @@ -1539,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1739,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1777,7 +1778,7 @@ def combineLocally(iterator): return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) @@ -1786,7 +1787,8 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type @@ -1800,9 +1802,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1817,13 +1819,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. @@ -1859,7 +1862,7 @@ def combine(iterator): return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer) From 53e430b111de01979e539d28772892d75d8f4f1d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Sep 2015 23:29:59 -0700 Subject: [PATCH 0089/1168] [SPARK-10716] [BUILD] spark-1.5.0-bin-hadoop2.6.tgz file doesn't uncompress on OS X due to hidden file Remove ._SUCCESS.crc hidden file that may cause problems in distribution tar archive, and is not used Author: Sean Owen Closes #8846 from srowen/SPARK-10716. --- .../test_support/sql/orc_partitioned/._SUCCESS.crc | Bin 8 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a890cd8d651d349a752d819d71d22c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 PcmYc;N@ieSU}69O2$TUk From 4c8b9715d942384b346efca121c50270f0e74f16 Mon Sep 17 00:00:00 2001 From: Jian Feng Date: Mon, 21 Sep 2015 23:36:41 -0700 Subject: [PATCH 0090/1168] [SPARK-10577] [PYSPARK] DataFrame hint for broadcast join https://issues.apache.org/jira/browse/SPARK-10577 Author: Jian Feng Closes #8801 from Jianfeng-chs/master. --- python/pyspark/sql/functions.py | 9 +++++++++ python/pyspark/sql/tests.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a60..fa04f4cd83b6f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -29,6 +29,7 @@ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -189,6 +190,14 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3e680f1030a71..645133b2b2d84 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1075,6 +1075,24 @@ def foo(): self.assertRaises(TypeError, foo) + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + class HiveContextSQLTests(ReusedPySparkTestCase): From 49970e5c56e0e56bc1d665fd66eeaffdca5e7924 Mon Sep 17 00:00:00 2001 From: Ewan Leith Date: Mon, 21 Sep 2015 23:43:20 -0700 Subject: [PATCH 0091/1168] [SPARK-10419] [SQL] Adding SQLServer support for datetimeoffset types to JdbcDialects Reading from Microsoft SQL Server over jdbc fails when the table contains datetimeoffset types. This patch registers a SQLServer JDBC Dialect that maps datetimeoffset to a String, as Microsoft suggest. Author: Ewan Leith Closes #8575 from realitymine-coordinator/sqlserver. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 68ebaaca6c53d..c70fea1c3f50e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -137,6 +137,8 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -260,3 +262,19 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. + */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Some(StringType) + } else None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5ab9381de4d66..c4b039a9c5359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -408,6 +408,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } From c34cafb749bccdbb6b055e4ea5bf0fe0f37cca53 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2015 23:46:00 -0700 Subject: [PATCH 0092/1168] [SPARK-10446][SQL] Support to specify join type when calling join with usingColumns JIRA: https://issues.apache.org/jira/browse/SPARK-10446 Currently the method `join(right: DataFrame, usingColumns: Seq[String])` only supports inner join. It is more convenient to have it support other join types. Author: Liang-Chi Hsieh Closes #8600 from viirya/usingcolumns_df. --- python/pyspark/sql/dataframe.py | 6 ++++- .../org/apache/spark/sql/DataFrame.scala | 22 ++++++++++++++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 13 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fb995fa3a76b5..80f8d8a0eb37d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -567,7 +567,11 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8f737c2023931..ba94d77b2e60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -484,6 +484,26 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( @@ -502,7 +522,7 @@ class DataFrame private[sql]( Join( joined.left, joined.right, - joinType = Inner, + joinType = JoinType(joinType), condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d85..56ad71ea4f487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From 96cc0d1fd8620f629d34d244eb85ef9b21f6c24e Mon Sep 17 00:00:00 2001 From: Madhusudanan Kandasamy Date: Tue, 22 Sep 2015 00:03:48 -0700 Subject: [PATCH 0093/1168] [SPARK-10458] [SPARK CORE] Added isStopped() method in SparkContext Added isStopped() method in SparkContext Author: Madhusudanan Kandasamy Closes #8749 from kmadhugit/SPARK-10458. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ebd8e946ee7a2..967fec9f42bcf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. + */ + def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus From ea6944b49a1ad05bd1b0b671993e8059a576a47c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 00:05:30 -0700 Subject: [PATCH 0094/1168] [Minor] style fix for previous commit f24316e --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 967fec9f42bcf..bf3aeb488d597 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** * @return true if context is stopped or in the midst of stopping. */ From ce31d86476755f5f6b0e976c460fdbbf90019983 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 00:07:30 -0700 Subject: [PATCH 0095/1168] [SPARK-8567] [SQL] Increase the timeout of o.a.s.sql.hive.HiveSparkSubmitSuite to 5 minutes. https://issues.apache.org/jira/browse/SPARK-8567 Looks like "SPARK-8368: includes jars passed in through --jars" is pretty flaky now. Based on some history runs, the time spent on a successful run may be from 1.5 minutes to almost 3 minutes. Let's try to increase the timeout and see if we can fix this test. https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.5-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.0,label=spark-test/385/testReport/junit/org.apache.spark.sql.hive/HiveSparkSubmitSuite/SPARK_8368__includes_jars_passed_in_through___jars/history/?start=25 Author: Yin Huai Closes #8850 from yhuai/SPARK-8567-anotherTry. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 97df249bdb6d6..5f1660b62d418 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -139,7 +139,7 @@ class HiveSparkSubmitSuite new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.waitFor() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised From 6b8c6bec3f2458f7c212d5e69a98828324ee3f01 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 00:09:29 -0700 Subject: [PATCH 0096/1168] [SQL] [MINOR] map -> foreach. DataFrame.explain should use foreach to print the explain content. Author: Reynold Xin Closes #8862 from rxin/map-foreach. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ba94d77b2e60e..a11140b717360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -320,9 +320,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + explain.queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println From 39bbf347e97d7501a88403256dae912a31968d24 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 22 Sep 2015 00:14:27 -0700 Subject: [PATCH 0097/1168] =?UTF-8?q?[SPARK-10695]=20[DOCUMENTATION]=20[ME?= =?UTF-8?q?SOS]=20Fixing=20incorrect=20value=20informati=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …on for spark.mesos.constraints parameter. Author: Akash Mishra Closes #8816 from SleepyThread/constraint-fix. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 460a66f37dd64..ec5a44d79212b 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -189,10 +189,10 @@ using `conf.set("spark.cores.max", "10")` (for example). You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support From 3815a48fc5ac025919310151a421083d59ce43a4 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Tue, 22 Sep 2015 11:03:21 +0100 Subject: [PATCH 0098/1168] [SPARK-10718] [BUILD] Update License on conf files and corresponding excludes file update Update License on conf files and corresponding excludes file update Author: Rekha Joshi Author: Joshi Closes #8842 from rekhajoshm/SPARK-10718. --- .rat-excludes | 12 ------------ conf/docker.properties.template | 17 +++++++++++++++++ conf/fairscheduler.xml.template | 18 ++++++++++++++++++ conf/log4j.properties.template | 17 +++++++++++++++++ conf/metrics.properties.template | 17 +++++++++++++++++ conf/slaves.template | 17 +++++++++++++++++ conf/spark-defaults.conf.template | 17 +++++++++++++++++ conf/spark-env.sh.template | 17 +++++++++++++++++ .../spark/log4j-defaults-repl.properties | 17 +++++++++++++++++ .../org/apache/spark/log4j-defaults.properties | 17 +++++++++++++++++ 10 files changed, 154 insertions(+), 12 deletions(-) diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb27..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9b..55cb094b4af46 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a35986..385b2e772d2c8 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403a..f3046be54d7c6 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f5..d6962e0da2f30 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20a..be42a638230b7 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e1363..19cba6e71ed19 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a7..990ded420be72 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8db..c85abc35b93bf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932b..d44cc85dcbd82 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender From 71afa877150f357987a0b5ea91ec70deb5fc9430 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Tue, 22 Sep 2015 11:05:24 +0100 Subject: [PATCH 0099/1168] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd Add java wrapper for random vector rdd holdenk srowen Author: Meihua Wu Closes #8841 from rotationsymmetry/SPARK-10706. --- .../spark/mllib/random/RandomRDDs.scala | 42 +++++++++++++++++++ .../mllib/random/JavaRandomRDDsSuite.java | 17 ++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index f8ff26b5795be..41d7c4d355f61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -855,6 +855,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f462..5728df5aeebdc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public void testArbitrary() { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 From 41279e63965ab4bdcbb44b7c78f9a42275a8ad46 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Sep 2015 10:19:08 -0700 Subject: [PATCH 0100/1168] [SPARK-9962] [ML] Decision Tree training: prevNodeIdsForInstances.unpersist() at end of training NodeIdCache: prevNodeIdsForInstances.unpersist() needs to be called at end of training. Author: Holden Karau Closes #8541 from holdenk/SPARK-9962-decission-tree-training-prevNodeIdsForiNstances-unpersist-at-end-of-training. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dcd..c5ad8df73fac9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b55..0abed5411143d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -166,6 +166,10 @@ private[spark] class NodeIdCache( fs.delete(new Path(old.getCheckpointFile.get), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } From 5c990ff3266e29471803ee651c5eb7b3e61a1069 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 11:00:33 -0700 Subject: [PATCH 0101/1168] [SPARK-10750] [ML] ML Param validate should print better error information Currently when you set illegal value for params of array type (such as IntArrayParam, DoubleArrayParam, StringArrayParam), it will throw IllegalArgumentException but with incomprehensible error information. Take ```VectorSlicer.setNames``` as an example: ```scala val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") // The value of setNames must be contain distinct elements, so the next line will throw exception. vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4", "f1")) ``` It will throw IllegalArgumentException as: ``` vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. java.lang.IllegalArgumentException: vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. ``` We should distinguish the value of array type from primitive type at Param.validate(value: T), and we will get better error information. ``` vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. java.lang.IllegalArgumentException: vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. ``` Author: Yanbo Liang Closes #8863 from yanboliang/spark-10750. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index de32b7218c277..48f6269e57e98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -65,7 +65,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } From 1e718183e48eb2d16806690868f939038004aa31 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 22 Sep 2015 11:01:32 -0700 Subject: [PATCH 0102/1168] [SPARK-9585] Delete the input format caching because some input format are non thread safe If we cache the InputFormat, all tasks on the same executor will share it. Some InputFormat is thread safety, but some are not, such as HiveHBaseTableInputFormat. If tasks share a non thread safe InputFormat, unexpected error may be occurs. To avoid it, I think we should delete the input format caching. Author: xutingjun Author: meiyoula <1039320815@qq.com> Author: Xutingjun Closes #7918 from XuTingjun/cached_inputFormat. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8f2655d63b797..77b57132b9f1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } From 77a6bce2d866083420dff12ca0d695dbe306f83c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Sep 2015 11:07:01 -0700 Subject: [PATCH 0103/1168] [SPARK-10593] [SQL] fix resolve output of Generate The output of Generate should not be resolved as Reference. Author: Davies Liu Closes #8755 from davies/view. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02f34cbf58ad0..bf72d47ce1ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -378,6 +378,22 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. + case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 55286f9f2fc5c..0ec9f08571082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 722f69cdca827..ae9482c10f126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8126d02335217..bb02473dd17ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1170,4 +1170,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } } From 32644d8e8ab0a909a956226c9b6e4397abec69fa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [PATCH 0104/1168] [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1b..6dc9a16e58531 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b466..d2e2fc4c110a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da7..9df4e551669cc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b839..a5eafb1b5529e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From b81b8bafbdce2cd19cd711e941c02ffe2a6c9982 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Sep 2015 12:14:15 -0700 Subject: [PATCH 0105/1168] [SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan Closes #8858 from cloud-fan/non-deter. --- .../sql/catalyst/optimizer/Optimizer.scala | 69 ++++++++++++++----- .../optimizer/SetOperationPushDownSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 41 +++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 324f40a051c38..63602eaa8ccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] { * Intersect: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). + */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection through UNION ALL - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 3fca47a023dc6..1595ad9327423 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1370713975f2f..d919877746c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(intersect.count() === 30) assert(except.count() === 70) } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } } From 5d52d79c40068153a27fb870ad9f872b2ba48159 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:29:39 -0700 Subject: [PATCH 0106/1168] [SPARK-10672] [SQL] Do not fail when we cannot save the metadata of a data source table in a hive compatible way https://issues.apache.org/jira/browse/SPARK-10672 With changes in this PR, we will fallback to same the metadata of a table in Spark SQL specific way if we fail to save it in a hive compatible way (Hive throws an exception because of its internal restrictions, e.g. binary and decimal types cannot be saved to parquet if the metastore is running Hive 0.13). I manually tested the fix with the following test in `DataSourceWithHiveMetastoreCatalogSuite` (`spark.sql.hive.metastore.version=0.13` and `spark.sql.hive.metastore.jars`=`maven`). ``` test(s"fail to save metadata of a parquet table in hive 0.13") { withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, cast("val_1" as binary) AS d2 """.stripMargin) sql( s"""describe formatted t """.stripMargin).collect.foreach(println) sqlContext.table("t").show } } } } ``` Without this fix, we will fail with the following error. ``` org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:619) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:576) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply$mcV$sp(ClientWrapper.scala:359) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:256) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:211) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:248) at org.apache.spark.sql.hive.client.ClientWrapper.createTable(ClientWrapper.scala:357) at org.apache.spark.sql.hive.HiveMetastoreCatalog.createDataSourceTable(HiveMetastoreCatalog.scala:358) at org.apache.spark.sql.hive.execution.CreateMetastoreDataSourceAsSelect.run(commands.scala:285) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:69) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:140) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:138) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:138) at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:58) at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:58) at org.apache.spark.sql.DataFrame.(DataFrame.scala:144) at org.apache.spark.sql.DataFrame.(DataFrame.scala:129) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:725) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2$$anonfun$apply$2.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:165) at org.apache.spark.sql.test.SQLTestUtils$class.withTable(SQLTestUtils.scala:150) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:162) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.test.SQLTestUtils$class.withTempPath(SQLTestUtils.scala:125) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTempPath(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.scalatest.FunSuite.runTest(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:318) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.org$scalatest$BeforeAndAfterAll$$super$run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) at sbt.ForkMain$Run$2.call(ForkMain.java:294) at sbt.ForkMain$Run$2.call(ForkMain.java:284) at java.util.concurrent.FutureTask.run(FutureTask.java:262) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.getObjectInspector(ArrayWritableObjectInspector.java:108) at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.(ArrayWritableObjectInspector.java:60) at org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe.initialize(ParquetHiveSerDe.java:113) at org.apache.hadoop.hive.metastore.MetaStoreUtils.getDeserializer(MetaStoreUtils.java:339) at org.apache.hadoop.hive.ql.metadata.Table.getDeserializerFromMetaStore(Table.java:288) at org.apache.hadoop.hive.ql.metadata.Table.checkValidity(Table.java:194) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:597) ... 76 more ``` Author: Yin Huai Closes #8824 from yhuai/datasourceMetadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 101 +++++++++--------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0c1b41e3377e3..012634cb5aeb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -309,69 +309,68 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } // TODO: Support persisting partitioned data source relations in Hive compatible format - val hiveTable = (maybeSerDe, dataSource.relation) match { + val qualifiedTableName = tableIdent.quotedString + val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - // Hive ParquetSerDe doesn't support decimal type until 1.2.0. - val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) - val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) - - val hiveParquetSupportsDecimal = client.version match { - case org.apache.spark.sql.hive.client.hive.v1_2 => true - case _ => false - } - - if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { - // If Hive version is below 1.2.0, we cannot save Hive compatible schema to - // metastore when the file format is Parquet and the schema has DecimalType. - logWarning { - "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + - "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + - s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." - } - newSparkSQLSpecificMetastoreTable() - } else { - logInfo { - "Persisting data source relation with a single input path into Hive metastore in " + - s"Hive compatible format. Input path: ${relation.paths.head}" - } - newHiveCompatibleMetastoreTable(relation, serde) - } + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => - logWarning { - "Persisting partitioned data source relation into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), relation: HadoopFsRelation) => - logWarning { - "Persisting data source relation with multiple input paths into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), _) => - logWarning { - s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + - "Persisting it into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) case _ => - logWarning { + val message = s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - "Persisting data source relation into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) } - client.createTable(hiveTable) + (hiveCompitiableTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatiable way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } def hiveDefaultTableFilePath(tableName: String): String = { From 59b6d34fd06d726017da46df97a8d4ed75bf87df Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:31:35 -0700 Subject: [PATCH 0107/1168] [SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai Closes #8854 from yhuai/SMJBug. --- .../codegen/GenerateProjection.scala | 2 ++ .../apache/spark/sql/execution/Window.scala | 9 ++++-- .../sql/execution/joins/SortMergeJoin.scala | 25 +++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2164ddf03d1b2..75524b568d685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { + // GenerateProjection does not work with UnsafeRows. + assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1c..f8929530c5036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -253,7 +253,11 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val grouping = if (child.outputsUnsafeRows) { + UnsafeProjection.create(partitionSpec, child.output) + } else { + newProjection(partitionSpec, child.output) + } // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -277,7 +281,8 @@ case class Window( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() rows = new CompactBuffer while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 906f20d2a7289..70a1af6a7063a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,9 +56,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) @@ -82,6 +79,28 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05b4127cbcaff..eca6f1073889a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1), Row(1))) } } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. + val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } } From b56e635c57b7763f095534c088c605f9c9108b7d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 14:11:46 -0700 Subject: [PATCH 0108/1168] [SPARK-10714] [SPARK-8632] [SPARK-10685] [SQL] Refactor Python UDF handling This patch refactors Python UDF handling: 1. Extract the per-partition Python UDF calling logic from PythonRDD into a PythonRunner. PythonRunner itself expects iterator as input/output, and thus has no dependency on RDD. This way, we can use PythonRunner directly in a mapPartitions call, or in the future in an environment without RDDs. 2. Use PythonRunner in Spark SQL's BatchPythonEvaluation. 3. Updated BatchPythonEvaluation to only use its input once, rather than twice. This should fix Python UDF performance regression in Spark 1.5. There are a number of small cleanups I wanted to do when I looked at the code, but I kept most of those out so the diff looks small. This basically implements the approach in https://github.com/apache/spark/pull/8833, but with some code moving around so the correctness doesn't depend on the inner workings of Spark serialization and task execution. Author: Reynold Xin Closes #8835 from rxin/python-iter-refactor. --- .../apache/spark/api/python/PythonRDD.scala | 54 ++++++++++--- .../spark/sql/execution/pythonUDFs.scala | 80 +++++++++++-------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 69da180593bb5..3788d1829758a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( parent: RDD[_], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. */ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d0411da6fdf5a..c35c726bfc503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -329,7 +329,13 @@ case class EvaluatePython( /** * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } From e883d2122be70a3771381a1d08ae0e64528d21ca Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 16:35:43 -0700 Subject: [PATCH 0109/1168] [SPARK-10640] History server fails to parse TaskCommitDenied ... simply because the code is missing! Author: Andrew Or Closes #8828 from andrewor14/task-end-reason-json. --- .../scala/org/apache/spark/TaskEndReason.scala | 6 +++++- .../org/apache/spark/util/JsonProtocol.scala | 13 +++++++++++++ .../apache/spark/util/JsonProtocolSuite.scala | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7137246bc34f2..9335c5f4160bf 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 99614a786bd93..40729fa5a4ffe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,6 +362,10 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) + case taskCommitDenied: TaskCommitDenied => + ("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) case ExecutorLostFailure(executorId, isNormalExit) => ("Executor ID" -> executorId) ~ ("Normal Exit" -> isNormalExit) @@ -770,6 +774,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -794,6 +799,14 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. + val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). map(_.extract[Boolean]) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 143c1b901df11..a24bf2931cca0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,6 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) @@ -352,6 +353,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -577,6 +589,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) case (ExecutorLostFailure(execId1, isNormalExit1), ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) From e98bcc9f47f850302e117090055d93bd2c4eb430 Mon Sep 17 00:00:00 2001 From: Zhichao Li Date: Tue, 22 Sep 2015 19:41:57 -0700 Subject: [PATCH 0110/1168] [SPARK-10310] [SQL] Fixes script transformation field/line delimiters **Please attribute this PR to `Zhichao Li `.** This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding field delimiter SerDe property to the default `LazySimpleSerDe`, and enabling default record reader/writer classes. Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing record reader/writer using `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s). Author: Cheng Lian Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters. --- .../org/apache/spark/sql/hive/HiveQl.scala | 52 ++++++++++--- .../hive/execution/ScriptTransformation.scala | 75 +++++++++++++++---- .../resources/data/scripts/test_transform.py | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++++++++++ .../execution/ScriptTransformationSuite.scala | 2 + 5 files changed, 152 insertions(+), 22 deletions(-) create mode 100755 sql/hive/src/test/resources/data/scripts/test_transform.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d5cd7e98b5267..256440a9a2e97 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException @@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (BaseSemanticAnalyzer.unescapeSQLString(name), BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 32bddbaeaeaf9..b30117f0de997 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -24,20 +24,22 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} import org.apache.spark.{Logging, TaskContext} /** @@ -58,6 +60,8 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) @@ -67,6 +71,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. // To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -96,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -109,6 +115,10 @@ case class ScriptTransformation( val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { outputSerde.getSerializedClass().newInstance @@ -134,15 +144,25 @@ case class ScriptTransformation( } } else if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject - try { - scriptOutputWritable.readFields(scriptOutputStream) - true - } catch { - case _: EOFException => - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) false + } else { + true + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } } else { true @@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception @@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -290,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -347,4 +376,24 @@ case class HiveScriptIOSchema ( serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 0000000000000..ac6d11d8b919c --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bb02473dd17ca..71823e32ad389 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(df, Row("text inside layer 2") :: Nil) } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index cb8d0fca8e693..7cfdb886b585d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) From dc68a6a7aeec6c92f0f3846bbd86108e750c90e8 Mon Sep 17 00:00:00 2001 From: Matt Hagen Date: Tue, 22 Sep 2015 21:14:25 -0700 Subject: [PATCH 0111/1168] [SPARK-10663] Removed unnecessary invocation of DataFrame.toDF method. The Scala example under the "Example: Pipeline" heading in this document initializes the "test" variable to a DataFrame. Because test is already a DF, there is not need to call test.toDF as the example does in a subsequent line: model.transform(test.toDF). So, I removed the extraneous toDF invocation. Author: Matt Hagen Closes #8875 from hagenhaus/SPARK-10663. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0427ac6695aa1..fd3a6167bc65e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -475,7 +475,7 @@ val test = sqlContext.createDataFrame(Seq( )).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => From 3089563148cf7538e8d10dbe17700c413b4be9ed Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Sep 2015 22:44:09 -0700 Subject: [PATCH 0112/1168] [SPARK-10652] [SPARK-10742] [STREAMING] Set meaningful job descriptions for all streaming jobs Here is the screenshot after adding the job descriptions to threads that run receivers and the scheduler thread running the batch jobs. ## All jobs page * Added job descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9924165/cda4a372-5cb1-11e5-91ca-d43a32c699e9.png) ## All stages page * Added stage descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9923814/2cce266a-5cae-11e5-8a3f-dad84d06c50e.png) ## Streaming batch details page * Added the +details link ![image](https://cloud.githubusercontent.com/assets/663212/9921977/24014a32-5c98-11e5-958e-457b6c38065b.png) Author: Tathagata Das Closes #8791 from tdas/SPARK-10652. --- .../scala/org/apache/spark/ui/UIUtils.scala | 62 ++++++++++++++++- .../apache/spark/ui/jobs/AllJobsPage.scala | 14 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 7 +- .../org/apache/spark/ui/UIUtilsSuite.scala | 66 +++++++++++++++++++ .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/scheduler/JobScheduler.scala | 15 ++++- .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../apache/spark/streaming/ui/BatchPage.scala | 33 ++++++---- .../streaming/StreamingContextSuite.scala | 2 +- 9 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da417724104..21dc8f0b65485 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -395,4 +397,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. So any + * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { From e3441d1b48d869ed411db2a98e5e0bb1ddd0c49c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Nov 2015 08:59:20 +0000 Subject: [PATCH 0786/1168] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Update to Commons Collections 3.2.2 to avoid any potential remote code execution vulnerability Author: Sean Owen Closes #9731 from srowen/SPARK-11652. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 940e2d8740bf1..ad849112ce76c 100644 --- a/pom.xml +++ b/pom.xml @@ -162,6 +162,8 @@ 3.1 3.4.1 + + 3.2.2 2.10.5 2.10 ${scala.version} @@ -475,6 +477,11 @@ commons-math3 ${commons.math3.version} + + org.apache.commons + commons-collections + ${commons.collections.version} + org.apache.ivy ivy From cc63e782dd7258d9495659b26ddf6defb0713c74 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Wed, 18 Nov 2015 09:10:15 +0000 Subject: [PATCH 0787/1168] rmse was wrongly calculated It was multiplying with U instaed of dividing by U Author: Viveka Kulharia Closes #9771 from vivkul/patch-1. --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e94..205ca02962bee 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): From a1985e7d413994dc69e29e715e6e2e82946ea948 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 18 Nov 2015 08:18:54 -0800 Subject: [PATCH 0788/1168] =?UTF-8?q?[SPARK-11804]=20[PYSPARK]=20Exception?= =?UTF-8?q?=20raise=20when=20using=20Jdbc=20predicates=20opt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion in PySpark Author: Jeff Zhang Closes #9791 from zjffdu/SPARK-11804. --- python/pyspark/sql/readwriter.py | 10 +++++----- python/pyspark/sql/utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b8ddb9feba34..e8f0d7ec77035 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -131,9 +132,7 @@ def load(self, path=None, format=None, schema=None, **options): if type(path) == list: paths = path gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] + jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) return self._df(self._jreader.load(jpaths)) else: return self._df(self._jreader.load(path)) @@ -269,8 +268,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b891..b0a0373372d20 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr From de9f3c7d1f043f6abb5399eee46bfdedb9c2b0c1 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 18 Nov 2015 08:32:03 -0800 Subject: [PATCH 0789/1168] [SPARK-11281][SPARKR] Add tests covering the issue. The goal of this PR is to add tests covering the issue to ensure that is was resolved by [SPARK-11086](https://issues.apache.org/jira/browse/SPARK-11086). Author: zero323 Closes #9743 from zero323/SPARK-11281-tests. --- R/pkg/inst/tests/test_sparkSQL.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8ff06276599e2..87ab33f6384b1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", { df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) - l <- list(list(a=1, b=2), list(a=3, b=4)) + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) @@ -292,11 +292,15 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - ldf <- data.frame(row.names=1:2) + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) - expect_equivalent(ldf, collect(sdf)) + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) }) # For test map type and struct type in DataFrame From 194563bbab7604d73bc4899a0b24360f2cd5682f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 18 Nov 2015 08:41:45 -0800 Subject: [PATCH 0790/1168] [SPARK-11773][SPARKR] Implement collection functions in SparkR. Author: Sun Rui Closes #9764 from sun-rui/SPARK-11773. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 109 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 10 ++- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 10 +++ 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1bc..260c9edce62e0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -98,6 +98,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", "ascii", "asin", @@ -215,6 +216,7 @@ exportMethods("%in%", "sinh", "size", "skewness", + "sort_array", "soundex", "stddev", "stddev_pop", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd105ba5bc9bb..34177e3cdd94f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2198,4 +2198,4 @@ setMethod("coltypes", rTypes[naIndices] <- types[naIndices] rTypes - }) \ No newline at end of file + }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d0255a62f155..ff0f438045c14 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -373,22 +373,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -980,22 +964,6 @@ setMethod("sinh", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @rdname size -#' @name size -#' @family collection_funcs -#' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) - column(jc) - }) - #' skewness #' #' Aggregate function: returns the skewness of the values in a group. @@ -2365,3 +2333,80 @@ setMethod("rowNumber", jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") column(jc) }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index afdeffc2abd83..0dcd05438222b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @export setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1076,4 +1084,4 @@ setGeneric("with") #' @rdname coltypes #' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index db3b2c4bbd799..45c77a86c9582 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -635,4 +635,4 @@ assignNewEnv <- function(data) { assign(x = cols[i], value = data[, cols[i]], envir = env) } env -} \ No newline at end of file +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 87ab33f6384b1..d9a94faff7ac0 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -878,6 +878,16 @@ test_that("column functions", { df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) }) # test_that("column binary mathfunctions", { From f6a1ce212d678285936830410b5fe889702964a7 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH 0791/1168] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel Closes #9779 from choochootrain/spark-11195-master. --- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c746..43c89b258f2fa 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2ee..f4965994d8277 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff529..bc72c3685e8c1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. + val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } From 0e73211489f3aa1cc1c981d82b206bc272257272 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:15:50 -0800 Subject: [PATCH 0792/1168] [SPARK-11803][SQL] fix Dataset self-join When we resolve the join operator, we may change the output of right side if self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly. Author: Wenchen Fan Closes #9806 from cloud-fan/self-join. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 817c20fdbb9f3..b644f6ad3096d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -498,13 +498,17 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan + val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() } @@ -513,7 +517,7 @@ class Dataset[T] private[sql]( withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + joined.analyzed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a522894c374f9..198962b8fb750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } - ignore("self join") { + test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) @@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.groupBy(p => p).count().collect().toSeq == Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - ignore("kryo encoder self join") { + test("kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == Set( (KryoData(1), KryoData(1)), From c9fe4c1ad2026587bcc68de41c6a60a93e8642f6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:23:12 -0800 Subject: [PATCH 0793/1168] [SPARK-11725][SQL] correctly handle null inputs for UDF If user use primitive parameters in UDF, there is no way for him to do the null-check for primitive inputs, so we are assuming the primitive input is null-propagatable for this case and return null if the input is null. Author: Wenchen Fan Closes #9770 from cloud-fan/udf. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 ++++ .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++ .../sql/catalyst/ScalaReflectionSuite.scala | 17 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++ 6 files changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e8..38828e59a2152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -719,6 +719,15 @@ trait ScalaReflection { } } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes + } + def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. case obj: Boolean => BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..f00c451b5981a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -85,6 +85,8 @@ class Analyzer( extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -1063,6 +1065,34 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case plan => plan transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3388cc20a9803..03b89221ef2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. */ case class ScalaUDF( function: AnyRef, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737f..4ea410d492b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 65f09b46afae1..08586a97411ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. + val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 35cdab50bdec9..5a7f24684d1b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(df("*")), Row(1, "a")) checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) null else i * 2 + } + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } From f8a4343f5370f0a4fdaf0372387224da29f54eb0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:33:17 -0800 Subject: [PATCH 0794/1168] [SPARK-11795][SQL] combine grouping attributes into a single NamedExpression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat. However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`. Author: Wenchen Fan Closes #9792 from cloud-fan/agg. --- .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 ++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee2148a..3f84e22a1025b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb750..b6db583dfe01f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { From 12b575c2a9db023b4654aaa997304e8ab8bda3e8 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 18 Nov 2015 11:35:41 -0800 Subject: [PATCH 0795/1168] [MINOR][BUILD] Ignore ensime cache Using ENSIME, I often have `.ensime_cache` polluting my source tree. This PR simply adds the cache directory to `.gitignore` Author: Jakob Odersky Closes #9708 from jodersky/master. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 08f2d8f7543f0..07524bc429e92 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log From d038c654f97e07835687c94aae9a7abd1b261e14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 11:49:12 -0800 Subject: [PATCH 0796/1168] [SPARK-11792] [SQL] [FOLLOW-UP] Change SizeEstimation to KnownSizeEstimation and make estimatedSize return Long instead of Option[Long] https://issues.apache.org/jira/browse/SPARK-11792 The main changes include: * Renaming `SizeEstimation` to `KnownSizeEstimation`. Hopefully this new name has more information. * Making `estimatedSize` return `Long` instead of `Option[Long]`. * In `UnsaveHashedRelation`, `estimatedSize` will delegate the work to `SizeEstimator` if we have not created a `BytesToBytesMap`. Since we will put `UnsaveHashedRelation` to `BlockManager`, it is generally good to let it provide a more accurate size estimation. Also, if we do not put `BytesToBytesMap` directly into `BlockerManager`, I feel it is not really necessary to make `BytesToBytesMap` extends `KnownSizeEstimation`. Author: Yin Huai Closes #9813 from yhuai/SPARK-11792-followup. --- .../org/apache/spark/util/SizeEstimator.scala | 30 ++++++++++--------- .../spark/util/SizeEstimatorSuite.scala | 14 ++------- .../sql/execution/joins/HashedRelation.scala | 12 +++++--- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index c3a2675ee5f45..09864e3f8392d 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. */ -private[spark] trait SizeEstimation { - def estimatedSize: Option[Long] +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long } /** @@ -209,18 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. } else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af123e6..101610e38014e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf53782..aebfea5832402 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = { From 9b6844bf5b5f58734c8870ca2adc276d7414cfe2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 18 Nov 2015 11:53:28 -0800 Subject: [PATCH 0797/1168] [SPARK-11739][SQL] clear the instantiated SQLContext Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust. Author: Davies Liu Closes #9706 from davies/clear_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------ .../spark/sql/MultiSQLContextsSuite.scala | 5 ++--- .../execution/ExchangeCoordinatorSuite.scala | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39d..39471d2fb79a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1229,7 +1229,7 @@ class SQLContext private[sql]( // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() } }) @@ -1270,13 +1270,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1284,12 +1284,17 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b1..34c5c68fd1c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed15..b96d50a70b85c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { From e229c91195eebca7c80fa597a0eae2ef47f994a5 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 18 Nov 2015 12:09:54 -0800 Subject: [PATCH 0798/1168] [SPARK-4557][STREAMING] Spark Streaming foreachRDD Java API method should accept a VoidFunction<...> Currently streaming foreachRDD Java API uses a function prototype requiring a return value of null. This PR deprecates the old method and uses VoidFunction to allow for more concise declaration. Also added VoidFunction2 to Java API in order to use in Streaming methods. Unit test is added for using foreachRDD with VoidFunction, and changes have been tested with Java 7 and Java 8 using lambdas. Author: Bryan Cutler Closes #9488 from BryanCutler/foreachRDD-VoidFunction-SPARK-4557. --- .../api/java/function/VoidFunction2.java | 27 ++++++++++++ .../apache/spark/streaming/Java8APISuite.java | 26 ++++++++++++ project/MimaExcludes.scala | 4 ++ .../streaming/api/java/JavaDStreamLike.scala | 24 ++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 41 ++++++++++++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 0000000000000..6c576ab678455 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. + */ +public interface VoidFunction2 extends Serializable { + public void call(T1 v1, T2 v2) throws Exception; +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 163ae92c12c6d..4eee97bc89613 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,6 +28,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -360,6 +361,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eb70d27c34c20..bb45d1bb12146 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -142,6 +142,10 @@ object MimaExcludes { "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f15..84acec7d8e330 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -308,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -316,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e4..609bb4413b6b1 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -37,7 +37,9 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -45,7 +47,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -768,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { From f90b7fd187d664ee65330d6e8326ba221e1f539d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 12:50:29 -0800 Subject: [PATCH 0799/1168] [SPARK-11809] Switch the default Mesos mode to coarse-grained mode Based on my conversions with people, I believe the consensus is that the coarse-grained mode is more stable and easier to reason about. It is best to use that as the default rather than the more flaky fine-grained mode. Author: Reynold Xin Closes #9795 from rxin/SPARK-11809. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- docs/job-scheduling.md | 2 +- docs/running-on-mesos.md | 27 ++++++++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b5645b08f92d4..ab374cb71286a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2710,7 +2710,7 @@ object SparkContext extends Logging { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a3c34cb6796fa..36327c6efeaf3 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5be208cf3461e..a197d0e373027 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re # Mesos Run Modes -Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". +Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". -In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. - -The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos +The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup overhead, but at the cost of reserving the Mesos resources for the complete duration of the application. -To run in coarse-grained mode, set the `spark.mesos.coarse` property in your -[SparkConf](configuration.html#spark-properties): +Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true +to turn it on explictly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows +multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, +where each application gets more or fewer machines as it ramps up and down, but it comes with an +additional overhead in launching each task. This mode may be inappropriate for low-latency +requirements like interactive queries or serving web requests. + +To run in coarse-grained mode, set the `spark.mesos.coarse` property to false in your +[SparkConf](configuration.html#spark-properties): + +{% highlight scala %} +conf.set("spark.mesos.coarse", "false") +{% endhighlight %} + You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} From dc28c548800fc4ebc4fbde806cb34a776b7f4123 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 18 Nov 2015 12:53:22 -0800 Subject: [PATCH 0800/1168] [SPARK-10745][CORE] Separate configs between shuffle and RPC [SPARK-6028](https://issues.apache.org/jira/browse/SPARK-6028) uses network module to implement RPC. However, there are some configurations named with `spark.shuffle` prefix in the network module. This PR refactors them to make sure the user can control them in shuffle and RPC separately. The user can use `spark.rpc.*` to set the configuration for netty RPC. Author: Shixiong Zhu Closes #9481 from zsxwing/SPARK-10745. --- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- .../network/netty/SparkTransportConf.scala | 12 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 8 +-- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../spark/network/util/TransportConf.java | 65 ++++++++++++++----- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 2 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../network/TransportClientFactorySuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockResolverSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../shuffle/RetryingBlockFetcherSuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 2 +- 23 files changed, 84 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index a039d543c35e7..e8a1e35c3fc48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler, true) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6b..b0694e3c6c8af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d685..84833f59d7afe 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22c..3e0c497969502 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dcd7a1f8a3230..3adb022822094 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend( private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 39fadd8783518..cc5f933393adf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 05b1eed7f3bef..fadb8fe7ed0ab 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 661c706af32b1..ab0007fb78993 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -122,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a47..1c775bcb3d9c1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955a..115135d44adbd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. */ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. */ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed0..dc5fa1cee69bc 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -83,7 +83,7 @@ public static void setUp() throws Exception { fp.write(fileContent); fp.close(); - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d54..42955ef69235a 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -60,7 +60,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f01..8eb56bdd9846f 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -49,7 +49,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec71..00158fd081626 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -89,7 +89,7 @@ public static void setUp() throws Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f447137419306..dac7d4a5b0a09 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -52,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -76,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -182,7 +182,7 @@ public void closeBlockClientsWithFactory() throws IOException { @Test public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { @Override public String get(String name) { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4da..b146899670180 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -207,7 +207,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +242,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -368,7 +368,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e6810..1c2fa4d0d462c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -70,7 +70,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea46..a9958232a1d28 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478b..532d7ab8d01bd 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,7 +35,7 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb9..2095f41d79c16 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda94948..08ddb3755bd08 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,7 +39,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f9241094..3a6ef0d3f8476 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cfe..ba6d30a74c673 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -120,7 +120,7 @@ protected void serviceInit(Configuration conf) { registeredExecutorFile = findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); From 4ea0b8b23c739d662f23c64d0231405e4197745f Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 18 Nov 2015 13:03:37 -0800 Subject: [PATCH 0801/1168] [SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function return Double.NaN for mean/average when count == 0 for all numeric types that is converted to Double, Decimal type continue to return null. Author: JihongMa Closes #9705 from JihongMA/SPARK-11720. --- python/pyspark/sql/dataframe.py | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Kurtosis.scala | 9 +++++---- .../expressions/aggregate/Skewness.scala | 9 +++++---- .../expressions/aggregate/Stddev.scala | 18 ++++++++++++++---- .../expressions/aggregate/Variance.scala | 18 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 18 ++++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a904..0dd75ba7ca820 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11eb1..d07d4c338cdfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1a51..c2bf2cb94116c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b82781..9411bcea2539a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b22543..eec79a9033e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da2805966..cf3a740305391 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d17623a4..71adf2148a403 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { @@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d1b7..6399b0165c4c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) From e890f3a13368423784400d254239d4933a10a62a Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [PATCH 0802/1168] [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../ml/regression/LinearRegression.scala | 77 ++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581983..ca55d5915e688 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fedad..2bdc0e184d734 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } From 96cab3d6beb01e5e130ff9f931c14d38ef0b107f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 13:16:31 -0800 Subject: [PATCH 0803/1168] [SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9786 from jkbradley/als-io. --- .../apache/spark/ml/recommendation/ALS.scala | 75 ++++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 14 +++- .../spark/ml/recommendation/ALSSuite.scala | 78 ++++++++++++++++--- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a944..d92514d2e239e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -182,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with Writable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -220,8 +223,60 @@ class ALSModel private[ml] ( val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends Readable[ALSModel] { + + @Since("1.6.0") + override def read: Reader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = read.load(path) + + private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = render("rank" -> instance.rank) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private[recommendation] class ALSModelReader extends Reader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.recommendation.ALSModel" + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank: Int = metadata.extraMetadata match { + case Some(m: JValue) => + (m \ "rank").extract[Int] + case None => + throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + + s" ${metadata.metadataStr}") + } + + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,7 +309,7 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def copy(extra: ParamMap): ALS = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends Readable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +415,12 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def read: Reader[ALS] = new DefaultParamsReader[ALS] + + @Since("1.6.0") + override def load(path: String): ALS = read.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dddb72af5ba78..d8ce907af5323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ - def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("extraMetadata" -> extraMetadata) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] + * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] * @param metadataStr Full metadata file String (for debugging) */ case class Metadata( @@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + extraMetadata: Option[JValue], metadataStr: String) /** @@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" + val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b1..2c3fb84160dcb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } From 6138b6ec4cbe559980ad97afd7824034494fca96 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 13:25:15 -0800 Subject: [PATCH 0804/1168] [SPARK-11813][MLLIB] Avoid serialization of vocab in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11813 I found the problem during training a large corpus. Avoid serialization of vocab in Word2Vec has 2 benefits. 1. Performance improvement for less serialization. 2. Increase the capacity of Word2Vec a lot. Currently in the fit of word2vec, the closure mainly includes serialization of Word2Vec and 2 global table. the main part of Word2vec is the vocab of size: vocab * 40 * 2 * 4 = 320 vocab 2 global table: vocab * vectorSize * 8. If vectorSize = 20, that's 160 vocab. Their sum cannot exceed Int.max due to the restriction of ByteArrayOutputStream. In any case, avoiding serialization of vocab helps decrease the size of the closure serialization, especially when vectorSize is small, thus to allow larger vocabulary. Actually there's another possible fix, make local copy of fields to avoid including Word2Vec in the closure. Let me know if that's preferred. Author: Yuhao Yang Closes #9803 from hhbyyh/w2vVocab. --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358a..7ab0d89d23a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -145,8 +145,8 @@ class Word2Vec extends Serializable with Logging { private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) From c5ab5a3394776683b9b6e381599c6d88ba17b637 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:30:29 -0800 Subject: [PATCH 0805/1168] [SPARK-11684][R][ML][DOC] Update SparkR glm API doc, user guide and example codes This PR includes: * Update SparkR:::glm, SparkR:::summary API docs. * Update SparkR machine learning user guide and example codes to show: * supporting feature interaction in R formula. * summary for gaussian GLM model. * coefficients for binomial GLM model. mengxr Author: Yanbo Liang Closes #9727 from yanboliang/spark-11684. --- R/pkg/R/mllib.R | 18 +++++-- docs/sparkr.md | 50 ++++++++++++++++--- .../ml/regression/LinearRegression.scala | 3 ++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f23e1c7f1fce4..8d3b4388ae575 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family +#' or a list with 'coefficients' component for binomial family. \cr +#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals +#' of the estimation, the 'coefficients' gives the estimated coefficients and their +#' estimated standard errors, t values and p-values. (It only available when model +#' fitted by normal solver.) \cr +#' For binomial family: the 'coefficients' gives the estimated coefficients. +#' See summary.glm for more information. \cr #' @rdname summary #' @export #' @examples diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c276..a744b76be7466 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,24 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ca55d5915e688..f7c44f0a51b8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,6 +145,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam From 00f3f947202db67bf9ffe4104b4e0c02bd68939e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:32:06 -0800 Subject: [PATCH 0806/1168] [SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol [SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang Closes #9811 from yanboliang/spark-11820. --- python/pyspark/ml/classification.py | 17 +++++++++-------- python/pyspark/ml/regression.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 603f2c7f798dc..4a2982e2047ff 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) >>> model.weights DenseVector([5.5...]) @@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() @@ -105,12 +106,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266bf..944e648ec8801 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. @@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +92,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From 11fbb0fa410a11a55f1f4e25b7e0e81bb4729789 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Nov 2015 14:05:18 -0800 Subject: [PATCH 0807/1168] [SPARK-11544][SQL] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9652 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++++--- .../datasources/json/JsonSuite.scala | 36 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df63..f9465157c936d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178affe..f09b61e838159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 8ad7aae902ed25ebca337bc108008ee3b80b129f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 15:42:07 -0800 Subject: [PATCH 0808/1168] [SPARK-11810][SQL] Java-based encoder for opaque types in Datasets. This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin Closes #9802 from rxin/SPARK-11810. --- .../scala/org/apache/spark/sql/Encoder.scala | 41 +++++++++--- .../sql/catalyst/expressions/objects.scala | 67 ++++++++++++------- .../catalyst/encoders/FlatEncoderSuite.scala | 27 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++++- 4 files changed, 130 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 79c2255641c06..1ed5111440c80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - */ - def kryo[T: ClassTag]: Encoder[T] = { - val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) - val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq(ser), - fromRowExpression = deser, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 489c6126f8cd3..acf0da240051e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} +import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} @@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy } } -/** Serializes an input object using Kryo serializer. */ -case class SerializeWithKryo(child: Expression) extends UnaryExpression { +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $kryo.serialize(${input.value}, null).array(); + ${ev.value} = $serializer.serialize(${input.value}, null).array(); } """ } @@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression { } /** - * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit - * parameter because TreeNode cannot copy implicit parameters. + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.javaType(dataType)}) - $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 2729db84897a2..6e0322fb6e019 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { // Kryo encoders encodeDecodeTest( "hello", - Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + encoderFor(Encoders.kryo[String]), "kryo string") encodeDecodeTest( - new NotJavaSerializable(15), - Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object serialization") + + // Java encoders + encodeDecodeTest( + "hello", + encoderFor(Encoders.javaSerialization[String]), + "java string") + encodeDecodeTest( + new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), + "java object serialization") } +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} -class NotJavaSerializable(val value: Int) { +/** For testing Java serialization based encoder. */ +class JavaSerializable(val value: Int) extends Serializable { override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[NotJavaSerializable].value + this.value == other.asInstanceOf[JavaSerializable].value } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b6db583dfe01f..89d964aa3e469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("kryo encoder") { + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - test("kryo encoder self join") { + test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(1)), (KryoData(2), KryoData(2)))) } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + ignore("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } } @@ -406,3 +425,16 @@ class KryoData(val a: Int) { object KryoData { def apply(a: Int): KryoData = new KryoData(a) } + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} From 3e5ac02b340967eb956e7baf89ad1756b174be6c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 15:47:49 -0800 Subject: [PATCH 0809/1168] [SPARK-6787][ML] add read/write to estimators under ml.feature (1) Add read/write support to the following estimators under spark.ml: * CountVectorizer * IDF * MinMaxScaler * StandardScaler (a little awkward because we store some params in spark.mllib model) * StringIndexer Added some necessary method for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley Author: Xiangrui Meng Closes #9798 from mengxr/SPARK-6787. --- .../spark/ml/feature/CountVectorizer.scala | 72 +++++++++++++++-- .../org/apache/spark/ml/feature/IDF.scala | 71 ++++++++++++++++- .../spark/ml/feature/MinMaxScaler.scala | 72 +++++++++++++++-- .../spark/ml/feature/StandardScaler.scala | 78 ++++++++++++++++++- .../spark/ml/feature/StringIndexer.scala | 70 +++++++++++++++-- .../ml/feature/CountVectorizerSuite.scala | 24 +++++- .../apache/spark/ml/feature/IDFSuite.scala | 19 ++++- .../spark/ml/feature/MinMaxScalerSuite.scala | 25 +++++- .../ml/feature/StandardScalerSuite.scala | 64 +++++++++++---- .../spark/ml/feature/StringIndexerSuite.scala | 19 ++++- 10 files changed, 467 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b85064..5ff9bfb7d1119 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { def this() = this(Identifiable.randomUID("cntVec")) @@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object CountVectorizer extends Readable[CountVectorizer] { + + @Since("1.6.0") + override def read: Reader[CountVectorizer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) } /** @@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: Writer = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends Readable[CountVectorizerModel] { + + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + + private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa0..53ad34ef12646 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { def this() = this(Identifiable.randomUID("idf")) @@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IDF extends Readable[IDF] { + + @Since("1.6.0") + override def read: Reader[IDF] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) } /** @@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with Writable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -117,4 +134,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: Writer = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends Readable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends Reader[IDFModel] { + + private val className = "org.apache.spark.ml.feature.IDFModel" + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b1727..24d964fae834e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object MinMaxScaler extends Readable[MinMaxScaler] { + + @Since("1.6.0") + override def read: Reader[MinMaxScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) } /** @@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends Readable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + + private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e75..ab04e5418dd4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with Writable { def this() = this(Identifiable.randomUID("stdScal")) @@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StandardScaler extends Readable[StandardScaler] { + + @Since("1.6.0") + override def read: Reader[StandardScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) } /** @@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { + extends Model[StandardScalerModel] with StandardScalerParams with Writable { + + import StandardScalerModel._ /** Standard deviation of the StandardScalerModel */ val std: Vector = scaler.std @@ -112,6 +129,14 @@ class StandardScalerModel private[ml] ( /** Mean of the StandardScalerModel */ val mean: Vector = scaler.mean + /** Whether to scale to unit standard deviation. */ + @Since("1.6.0") + def getWithStd: Boolean = scaler.withStd + + /** Whether to center data with mean. */ + @Since("1.6.0") + def getWithMean: Boolean = scaler.withMean + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -138,4 +163,49 @@ class StandardScalerModel private[ml] ( val copied = new StandardScalerModel(uid, scaler) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends Readable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + + private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends Reader[StandardScalerModel] { + + private val className = "org.apache.spark.ml.feature.StandardScalerModel" + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = + sqlContext.read.parquet(dataPath) + .select("std", "mean", "withStd", "withMean") + .head() + // This is very likely to change in the future because withStd and withMean should be params. + val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) + val model = new StandardScalerModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f782a272d11db..f16f6afc002d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with Writable { def this() = this(Identifiable.randomUID("strIdx")) @@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StringIndexer extends Readable[StringIndexer] { + + @Since("1.6.0") + override def read: Reader[StringIndexer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) } /** @@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with Writable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -176,6 +193,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends Readable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends Reader[StringIndexerModel] { + + private val className = "org.apache.spark.ml.feature.StringIndexerModel" + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af0..9c9999017317d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429b..bc958c15857ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea34..09183fe65b722 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { val sqlContext = new SQLContext(sc) @@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 879a3ae875004..49a4b2efe0c29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ ) } - def assertResult(dataframe: DataFrame): Unit = { - dataframe.select("standarded_features", "expected").collect().foreach { + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "The vector value is not correct after standardization.") } } + test("params") { + ParamsSuite.checkParams(new StandardScaler) + val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) + ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + } + test("Standardization with default parameter") { val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") - val standardscaler0 = new StandardScaler() + val standardScaler0 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .fit(df0) - assertResult(standardscaler0.transform(df0)) + assertResult(standardScaler0.transform(df0)) } test("Standardization with setter") { @@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") - val standardscaler1 = new StandardScaler() + val standardScaler1 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(true) .fit(df1) - val standardscaler2 = new StandardScaler() + val standardScaler2 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(false) .fit(df2) - val standardscaler3 = new StandardScaler() + val standardScaler3 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(false) .setWithStd(false) .fit(df3) - assertResult(standardscaler1.transform(df1)) - assertResult(standardscaler2.transform(df2)) - assertResult(standardscaler3.transform(df3)) + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val oldModel = new feature.StandardScalerModel( + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) + val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + assert(newInstance.getWithStd === instance.getWithStd) + assert(newInstance.getWithMean === instance.getWithMean) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index be37bfb438833..749bfac747826 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,23 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + test("IndexToString params") { val idxToStr = new IndexToString() ParamsSuite.checkParams(idxToStr) @@ -175,7 +192,7 @@ class StringIndexerSuite assert(outSchema("output").dataType === StringType) } - test("read/write") { + test("IndexToString read/write") { val t = new IndexToString() .setInputCol("myInputCol") .setOutputCol("myOutputCol") From ef4de26d68d8d8fc293396c3d77e39296b339180 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 15:55:41 -0800 Subject: [PATCH 0810/1168] [SPARK-11649] Properly set Akka frame size in SparkListenerSuite test SparkListenerSuite's _"onTaskGettingResult() called when result fetched remotely"_ test was extremely slow (1 to 4 minutes to run) and recently became extremely flaky, frequently failing with OutOfMemoryError. The root cause was the fact that this was using `System.setProperty` to set the Akka frame size, which was not actually modifying the frame size. As a result, this test would allocate much more data than necessary. The fix here is to simply use SparkConf in order to configure the frame size. Author: Josh Rosen Closes #9822 from JoshRosen/SPARK-11649. --- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c936..84e545851f49e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } From b62c54898fcd940bfd953e701387139d12cf0a39 Mon Sep 17 00:00:00 2001 From: Derek Dagit Date: Wed, 18 Nov 2015 15:56:54 -0800 Subject: [PATCH 0811/1168] [SPARK-10930] History "Stages" page "duration" can be confusing Author: Derek Dagit Closes #9051 from d2r/spark-10930-ui-max-task-dur. --- .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b6009..2a1c3c1a50ec9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes From 435c3faebddf80d3e65fa3fb8db769ef5f836873 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 16:00:35 -0800 Subject: [PATCH 0812/1168] [SPARK-11495] Fix potential socket / file handle leaks that were found via static analysis The HP Fortify Opens Source Review team (https://www.hpfod.com/open-source-review-project) reported a handful of potential resource leaks that were discovered using their static analysis tool. We should fix the issues identified by their scan. Author: Josh Rosen Closes #9455 from JoshRosen/fix-potential-resource-leaks. --- .../spark/unsafe/map/BytesToBytesMap.java | 7 ++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 38 +++++++++++-------- .../streaming/JavaCustomReceiver.java | 31 +++++++-------- .../network/ChunkFetchIntegrationSuite.java | 15 ++++++-- .../shuffle/TestShuffleDataContext.java | 32 ++++++++++------ .../spark/streaming/JavaReceiverAPISuite.java | 20 ++++++---- 6 files changed, 90 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 04694dc54418c..3387f9a4177ce 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,6 +273,7 @@ private void advanceToNextPage() { } } try { + Closeables.close(reader, /* swallowIOException = */ false); reader = spillWriters.getFirst().getReader(blockManager); recordsInPage = -1; } catch (IOException e) { @@ -318,6 +320,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 039e940a357ea..dcb13e6581e54 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,8 +20,7 @@ import java.io.*; import com.google.common.io.ByteStreams; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.io.Closeables; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -31,10 +30,8 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). */ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { - private final File file; private InputStream in; private DataInputStream din; @@ -52,11 +49,15 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { assert (file.length() > 0); - this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + try { + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } } @Override @@ -75,12 +76,7 @@ public void loadNext() throws IOException { ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { - in.close(); - if (!file.delete() && file.exists()) { - logger.warn("Unable to delete spill file {}", file.getPath()); - } - in = null; - din = null; + close(); } } @@ -103,4 +99,16 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } + + @Override + public void close() throws IOException { + if (in != null) { + try { + in.close(); + } finally { + in = null; + din = null; + } + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e6..4b50fbf59f80e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dc5fa1cee69bc..50a324e293386 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,10 +79,15 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); @@ -117,6 +123,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c7..7ac1ca128aed0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b97..7a8ef9d14784c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -121,14 +122,19 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); From 3e97c6570ff4eced900ce2b8995439233996addb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Nov 2015 16:08:06 -0800 Subject: [PATCH 0813/1168] [SPARK-11814][STREAMING] Add better default checkpoint duration DStream checkpoint interval is by default set at max(10 second, batch interval). That's bad for large batch intervals where the checkpoint interval = batch interval, and RDDs get checkpointed every batch. This PR is to set the checkpoint interval of trackStateByKey to 10 * batch duration. Author: Tathagata Das Closes #9805 from tdas/SPARK-11814. --- .../streaming/dstream/TrackStateDStream.scala | 13 ++++++ .../streaming/TrackStateByKeySuite.scala | 44 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 98e881e6ae115..0ada1111ce30a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ /** * :: Experimental :: @@ -120,6 +121,14 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Enable automatic checkpointing */ override val mustCheckpoint = true + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD @@ -141,3 +150,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } } } + +private[streaming] object InternalTrackStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index e3072b4442840..58aef74c0040f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -22,9 +22,10 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag +import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + test("state - get, exists, update, remove, ") { var state: StateImpl[Int] = null @@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } + test("trackStateByKey - checkpoint durations") { + val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + try { + ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (value: Option[Int], state: State[Int]) => 0 + val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) + val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + trackStateStream.checkpoint(d) + } + trackStateStream.register() + ssc.start() // should initialize all the checkpoint durations + assert(trackStateStream.checkpointDuration === null) + assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], From 5d9d43cff75310cf9c0daaf718750127673e2729 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 18 Nov 2015 16:19:00 -0800 Subject: [PATCH 0814/1168] [SPARK-11791] Fix flaky test in BatchedWriteAheadLogSuite stack trace of failure: ``` org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 62 times over 1.006322071 seconds. Last failure message: Argument(s) are different! Wanted: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.BatchedWriteAheadLogSuite$$anonfun$23$$anonfun$apply$mcV$sp$15.apply(WriteAheadLogSuite.scala:518) Actual invocation has different arguments: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.WriteAheadLogSuite$BlockingWriteAheadLog.write(WriteAheadLogSuite.scala:756) ``` I believe the issue was that due to a race condition, the ordering of the events could be messed up in the final ByteBuffer, therefore the comparison fails. By adding eventually between the requests, we make sure the ordering is preserved. Note that in real life situations, the ordering across threads will not matter. Another solution would be to implement a custom mockito matcher that sorts and then compares the results, but that kind of sounds like overkill to me. Let me know what you think tdas zsxwing Author: Burak Yavuz Closes #9790 from brkyvz/fix-flaky-2. --- .../spark/streaming/util/WriteAheadLogSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7f80d6ecdbbb5..eaa88ea3cd380 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor import org.mockito.Matchers.{eq => meq} import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -507,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } blockingWal.allowWrite() - val buffer1 = wrapArrayArrayByte(Array(event1)) - val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) eventually(timeout(1 second)) { assert(batchedWal.invokePrivate(queueLength()) === 0) - verify(wal, times(1)).write(meq(buffer1), meq(3L)) + verify(wal, times(1)).write(meq(buffer), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. - verify(wal, times(1)).write(meq(buffer2), meq(10L)) + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) } } From 4cd0426f640798f77eb2a78c964cfb98a42a3a26 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 16:48:09 -0800 Subject: [PATCH 0815/1168] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders Before this PR there were two things that would blow up if you called `df.as[MyClass]` if `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoder and is not yet sufficient to make them work in the REPL. Currently, the implementation show in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust Closes #9602 from marmbrus/dataset-replClasses. --- .../spark/sql/catalyst/ScalaReflection.scala | 81 ++++++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 26 +++++- .../sql/catalyst/encoders/OuterScopes.scala | 42 ++++++++++ .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 +-- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 6 ++ .../sql/catalyst/expressions/objects.scala | 42 ++++++++-- .../encoders/ExpressionEncoderSuite.scala | 7 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 19 ++--- 17 files changed, 193 insertions(+), 82 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 38828e59a2152..59ccf356f2c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,30 +42,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including @@ -114,7 +79,9 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } @@ -640,6 +607,48 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b977f278c5b5c..456b595008479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -211,7 +213,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. */ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -219,7 +223,23 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 0000000000000..a753b187bcd32 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. + */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 55c4ee11b20f4..2914c6ee790ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f34..a31574c251af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55a..40189f0877764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e1301..f229f2000d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923f..b7926bda3de19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a23725..7b6c9373ebe30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b3..da602d9b4bce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26d..e34fd49be8389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,12 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. + */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index acf0da240051e..f865a9408ef4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -178,6 +179,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. @@ -189,12 +199,15 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. */ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -209,30 +222,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e4..cde0364f3dd9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537d..1798514c5c38b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends ExpressionEncoderSuite { + outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad3096d..bdcdc5d47cbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a1025b..7e5acbe8517d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8c..6ce41aaf01e27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From e9fee01bcfe91a3d76efb61cf931b915ddaa4e5d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 18:34:01 -0800 Subject: [PATCH 0816/1168] [SPARK-11839][ML] refactor save/write traits * add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng Closes #9827 from mengxr/SPARK-11839. --- .../scala/org/apache/spark/ml/Pipeline.scala | 40 +++++++++---------- .../classification/LogisticRegression.scala | 29 +++++++------- .../apache/spark/ml/feature/Binarizer.scala | 12 ++---- .../apache/spark/ml/feature/Bucketizer.scala | 12 ++---- .../spark/ml/feature/CountVectorizer.scala | 22 ++++------ .../org/apache/spark/ml/feature/DCT.scala | 12 ++---- .../apache/spark/ml/feature/HashingTF.scala | 12 ++---- .../org/apache/spark/ml/feature/IDF.scala | 23 +++++------ .../apache/spark/ml/feature/Interaction.scala | 12 ++---- .../spark/ml/feature/MinMaxScaler.scala | 22 ++++------ .../org/apache/spark/ml/feature/NGram.scala | 12 ++---- .../apache/spark/ml/feature/Normalizer.scala | 12 ++---- .../spark/ml/feature/OneHotEncoder.scala | 12 ++---- .../ml/feature/PolynomialExpansion.scala | 12 ++---- .../ml/feature/QuantileDiscretizer.scala | 12 ++---- .../spark/ml/feature/SQLTransformer.scala | 13 ++---- .../spark/ml/feature/StandardScaler.scala | 22 ++++------ .../spark/ml/feature/StopWordsRemover.scala | 12 ++---- .../spark/ml/feature/StringIndexer.scala | 32 +++++---------- .../apache/spark/ml/feature/Tokenizer.scala | 24 +++-------- .../spark/ml/feature/VectorAssembler.scala | 12 ++---- .../spark/ml/feature/VectorSlicer.scala | 12 ++---- .../apache/spark/ml/recommendation/ALS.scala | 27 +++++-------- .../ml/regression/LinearRegression.scala | 30 ++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 40 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 14 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++---- 27 files changed, 190 insertions(+), 321 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f42be..b0f22e042ec56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcbf47..a3cc49f7f018c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8f00..63c06581482ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70aa07..324353a96afb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d1119..4969cf42450d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 +264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a616173ee..6bed72164a1da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f5617..9e15835429a38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef12646..0e00ef6f2ee20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc9da..2181119f04a5d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae834e..ed24eabb50444 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd95199f..65414ecbefbbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d5058..c2d514fd9629e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa21e..d70164eaf0224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 49415398325fd..08610593fadda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d2967..7bf67c6325a35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. */ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff301a..3a735017ba836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418dd4f..1f689c1da1ba9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dcc7c..318808596dc6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc002d8..97a2e4f6d6ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StringIndexer extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2fa7..8ad7bbedaab5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292ca2..0feec0549852b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b574..5410a50bc2e47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e239e..795b73c4c2121 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51b8a..7ba1a60edaf71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af5323..ff9322dba122a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. @@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. */ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895acb0c..12aba6bc6dbeb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce9418..84d06b43d6224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From 3b1b13e04ed50e6beb0f893f8e26acd778ebcdb0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 18:34:36 -0800 Subject: [PATCH 0817/1168] [SPARK-11833][SQL] Add Java tests for Kryo/Java Dataset encoders Also added some nicer error messages for incompatible types (private types and primitive types) for Kryo/Java encoder. Author: Reynold Xin Closes #9823 from rxin/SPARK-11833. --- .../scala/org/apache/spark/sql/Encoder.scala | 69 +++++++++++------ .../encoders/EncoderErrorMessageSuite.scala | 40 ++++++++++ .../catalyst/encoders/FlatEncoderSuite.scala | 22 ++---- .../apache/spark/sql/JavaDatasetSuite.java | 75 ++++++++++++++++++- 4 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ed5111440c80..d54f2854fb33f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.lang.reflect.Modifier + import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} @@ -43,30 +45,28 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** A way to construct encoders using generic serializers. */ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - toRowExpressions = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -75,6 +75,8 @@ object Encoders { * serialization. This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -83,17 +85,40 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } def tuple[T1, T2]( e1: Encoder[T1], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 0000000000000..0b2a10bb04c10 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. + + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 6e0322fb6e019..07523d49f4266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -74,24 +74,14 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { FlatEncoder[Map[Int, Map[String, Int]]], "map of map") // Kryo encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.kryo[String]), - "kryo string") - encodeDecodeTest( - new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), - "kryo object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") + encodeDecodeTest(new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") // Java encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.javaSerialization[String]), - "java string") - encodeDecodeTest( - new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), - "java object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") + encodeDecodeTest(new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") } /** For testing Kryo serialization based encoder. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d9b22506fbd3b..ce40dd856f679 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -24,6 +24,7 @@ import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; + import org.junit.*; import org.apache.spark.Accumulator; @@ -410,8 +411,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); Assert.assertEquals( Arrays.asList( - new Tuple4("a", 3, 3L, 2L), - new Tuple4("b", 3, 3L, 1L)), + new Tuple4<>("a", 3, 3L, 2L), + new Tuple4<>("b", 3, 3L, 1L)), agged2.collectAsList()); } @@ -437,4 +438,74 @@ public Integer finish(Integer reduction) { return reduction; } } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } } From 4223dcf8c2dedd80992eea47dd73c5d599e2932d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 18:38:45 -0800 Subject: [PATCH 0818/1168] [SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas. This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li Closes #9774 from nongli/parquet. --- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 41 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 + .../expressions/codegen/BufferHolder.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 20 +- .../SpecificParquetRecordReaderBase.java | 240 +++++++ .../parquet/UnsafeRowParquetRecordReader.java | 593 ++++++++++++++++++ .../parquet/CatalystRowConverter.scala | 48 +- .../parquet/ParquetFilterSuite.scala | 4 +- 8 files changed, 944 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f39085..4d176332b69ce 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? + */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a4..33769363a0ed5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065d..d26b1b187c27b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -41,12 +50,23 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb4..e227c0dec9748 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public void initialize(BufferHolder holder, int numFields) { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public void initialize(BufferHolder holder, int numFields) { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 0000000000000..2ed30c1f5a8d9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. + * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? + */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 0000000000000..8a92e489ccb7c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. + */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. + */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. + rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. + */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? + page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb1..94298fae2d69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. + val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77af3f..c8028a5ef5528 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => From da3686f5f0fb6c1b011f43f449d78c7687780681 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 18:41:40 -0800 Subject: [PATCH 0819/1168] Revert "[SPARK-11544][SQL] sqlContext doesn't use PathFilter" This reverts commit 54db79702513e11335c33bcf3a03c59e965e6f16. --- .../apache/spark/sql/sources/interfaces.scala | 25 +++---------- .../datasources/json/JsonSuite.scala | 36 ++----------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936d..b3d3bdf50df63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -448,15 +447,9 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + logInfo(s"Listing $qualified on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - } + Try(fs.listStatus(qualified)).getOrElse(Array.empty) }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -854,16 +847,8 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f09b61e838159..6042b1178affe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,27 +19,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TestFileFilter extends PathFilter { - override def accept(path: Path): Boolean = path.getParent.getName != "p=2" -} - class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1398,28 +1390,4 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } - - test("SPARK-11544 test pathfilter") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(2) - df.write.json(path + "/p=1") - df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) - - val clonedConf = new Configuration(hadoopConfiguration) - try { - hadoopConfiguration.setClass( - "mapreduce.input.pathFilter.class", - classOf[TestFileFilter], - classOf[PathFilter]) - assert(sqlContext.read.json(path).count() === 2) - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - } } From 35d788cb37bef3ef736fa2806f84b1723fe67ee9 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 18:49:46 -0800 Subject: [PATCH 0820/1168] [SPARK-11816][ML] fix some style issue in ML/MLlib examples jira: https://issues.apache.org/jira/browse/SPARK-11816 Currently I only fixed some obvious comments issue like // scalastyle:off println on the bottom. Yet the style in examples is not quite consistent, like only half of the examples are with // Example usage: ./bin/run-example mllib.FPGrowthExample \, Author: Yuhao Yang Closes #9808 from hhbyyh/exampleStyle. --- .../java/org/apache/spark/examples/ml/JavaKMeansExample.java | 2 +- .../apache/spark/examples/ml/AFTSurvivalRegressionExample.scala | 2 +- .../spark/examples/ml/DecisionTreeClassificationExample.scala | 1 + .../spark/examples/ml/DecisionTreeRegressionExample.scala | 1 + .../examples/ml/MultilayerPerceptronClassifierExample.scala | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465c..47665ff2b1f3c 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -41,7 +41,7 @@ * An example demonstrating a k-means clustering. * Run with *
    - * bin/run-example ml.JavaSimpleParamsExample  
    + * bin/run-example ml.JavaKMeansExample  
      * 
    */ public class JavaKMeansExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 5da285e83681f..f4b3613ccb94f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -59,4 +59,4 @@ object AFTSurvivalRegressionExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index ff8a0a90f1e44..db024b5cad935 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -90,3 +90,4 @@ object DecisionTreeClassificationExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index fc402724d2156..ad01f55df72b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -78,3 +78,4 @@ object DecisionTreeRegressionExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 146b83c8be490..9c98076bd24b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,4 +66,4 @@ object MultilayerPerceptronClassifierExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println From ed863a58a37c58ab9643ad6668ecaf674e05ce9f Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 18 Nov 2015 19:37:14 -0800 Subject: [PATCH 0821/1168] [SPARK-11614][SQL] serde parameters should be set only when all params are ready see HIVE-7975 and HIVE-12373 With changed semantic of setters in thrift objects in hive, setter should be called only after all parameters are set. It's not problem of current state but will be a problem in some day. Author: navis.ryu Closes #9580 from navis/SPARK-11614. --- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4e..9a981d02ad67c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -804,12 +804,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } From 81c5fc2449c6cc5c88bc5610b62ec2c2c2094be9 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 21:44:01 -0800 Subject: [PATCH 0822/1168] [SPARK-11842][ML] Small cleanups to existing Readers and Writers Updates: * Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel. * Strengthen privacy to class and companion object for Writers and Readers * Change LogisticRegressionSuite read/write test to fit intercept * Add Since versions for read/write methods in Pipeline, LogisticRegression * Switch from hand-written class names in Readers to using getClass CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you! Author: Joseph K. Bradley Closes #9829 from jkbradley/ml-io-cleanups. --- .../scala/org/apache/spark/ml/Pipeline.scala | 22 +++++++++++++------ .../classification/LogisticRegression.scala | 19 ++++++++++------ .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 6 ++--- .../ml/regression/LinearRegression.scala | 4 ++-- .../LogisticRegressionSuite.scala | 2 +- 10 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b0f22e042ec56..6f15b37abcb30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.MLReader import org.apache.spark.ml.util.MLWriter @@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + @Since("1.6.0") override def write: MLWriter = new Pipeline.PipelineWriter(this) } +@Since("1.6.0") object Pipeline extends MLReadable[Pipeline] { + @Since("1.6.0") override def read: MLReader[Pipeline] = new PipelineReader + @Since("1.6.0") override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends MLReader[Pipeline] { + private class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + private val className = classOf[Pipeline].getName override def load(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) @@ -333,18 +337,22 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + @Since("1.6.0") override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } +@Since("1.6.0") object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite + @Since("1.6.0") override def read: MLReader[PipelineModel] = new PipelineModelReader + @Since("1.6.0") override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends MLReader[PipelineModel] { + private class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" + private val className = classOf[PipelineModel].getName override def load(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a3cc49f7f018c..418bbdc9a058f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] ( * * This also does not save the [[parent]] currently. */ + @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } +@Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + @Since("1.6.0") override def load(path: String): LogisticRegressionModel = super.load(path) /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ - private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { private case class Data( @@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } - private[classification] class LogisticRegressionModelReader + private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) @@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -839,7 +844,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 4969cf42450d2..b9e2144c0ad40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { - private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 0e00ef6f2ee20..f7b0f29a27c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] { private class IDFModelReader extends MLReader[IDFModel] { - private val className = "org.apache.spark.ml.feature.IDFModel" + private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ed24eabb50444..c2866f5eceff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { - private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1f689c1da1ba9..6d545219ebf49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private class StandardScalerModelReader extends MLReader[StandardScalerModel] { - private val className = "org.apache.spark.ml.feature.StandardScalerModel" + private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 97a2e4f6d6ca4..5c40c35eeaa48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private class StringIndexerModelReader extends MLReader[StringIndexerModel] { - private val className = "org.apache.spark.ml.feature.StringIndexerModel" + private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 795b73c4c2121..4d35177ad9b0f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] { } } - private[recommendation] class ALSModelReader extends MLReader[ALSModel] { + private class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.recommendation.ALSModel" + private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7ba1a60edaf71..70ccec766c471 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48ce1bb630685..a9a6ff8a783d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -898,7 +898,7 @@ object LogisticRegressionSuite { "regParam" -> 0.01, "elasticNetParam" -> 0.1, "maxIter" -> 2, // intentionally small - "fitIntercept" -> false, + "fitIntercept" -> true, "tol" -> 0.8, "standardization" -> false, "threshold" -> 0.6 From 24200188e88dfdc32b47b2b09442a4425cb61073 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 18 Nov 2015 23:32:49 -0800 Subject: [PATCH 0823/1168] [SPARK-11339][SPARKR] Document the list of functions in R base package that are masked by functions with same name in SparkR Added tests for function that are reported as masked, to make sure the base:: or stats:: function can be called. For those we can't call, added them to SparkR programming guide. It would seem to me `table, sample, subset, filter, cov` not working are not actually expected - I investigated/experimented with them but couldn't get them to work. It looks like as they are defined in base or stats they are missing the S3 generic, eg. ``` > methods("transform") [1] transform,ANY-method transform.data.frame [3] transform,DataFrame-method transform.default see '?methods' for accessing help and source code > methods("subset") [1] subset.data.frame subset,DataFrame-method subset.default [4] subset.matrix see '?methods' for accessing help and source code Warning message: In .S3methods(generic.function, class, parent.frame()) : function 'subset' appears not to be S3 generic; found functions that look like S3 methods ``` Any idea? More information on masking: http://www.ats.ucla.edu/stat/r/faq/referencing_objects.htm http://www.sfu.ca/~sweldon/howTo/guide4.pdf This is what the output doc looks like (minus css): ![image](https://cloud.githubusercontent.com/assets/8969467/11229714/2946e5de-8d4d-11e5-94b0-dda9696b6fdd.png) Author: felixcheung Closes #9785 from felixcheung/rmasked. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 2 +- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/test_mllib.R | 5 +++++ R/pkg/inst/tests/test_sparkSQL.R | 33 +++++++++++++++++++++++++++- docs/sparkr.md | 37 +++++++++++++++++++++++++++++++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34177e3cdd94f..06b0108b1389e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,7 +2152,7 @@ setMethod("with", }) #' Returns the column types of a DataFrame. -#' +#' #' @name coltypes #' @title Get column types of a DataFrame #' @family dataframe_funcs diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ff0f438045c14..25a1f22101494 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,7 +2204,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0dcd05438222b..71004a05ba611 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export @@ -790,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index d497ad8c9daa3..e0667e5e22c18 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -31,6 +31,11 @@ test_that("glm and predict", { model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("glm should work with long formula", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d9a94faff7ac0..3f4f319fe745d 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -433,6 +433,10 @@ test_that("table() returns a new DataFrame", { expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") }) test_that("toRDD() returns an RRDD", { @@ -673,6 +677,9 @@ test_that("sample on a DataFrame", { # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) }) test_that("select operators", { @@ -753,6 +760,9 @@ test_that("subsetting", { df6 <- subset(df, df$age %in% c(30), c(1,2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) }) test_that("selectExpr() on a DataFrame", { @@ -888,6 +898,9 @@ test_that("column functions", { expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) }) # test_that("column binary mathfunctions", { @@ -1086,7 +1099,7 @@ test_that("group by, agg functions", { gd3_local <- collect(agg(gd3, var(df8$age))) expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) - # make sure base:: or stats::sd, var are working + # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) @@ -1138,6 +1151,9 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered5), 1) filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) }) test_that("join() and merge() on a DataFrame", { @@ -1284,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_is(unioned, "DataFrame") expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) }) test_that("withColumn() and withColumnRenamed()", { @@ -1365,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { @@ -1448,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) }) test_that("fillna() on a DataFrame", { @@ -1510,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) result <- corr(df, "singles", "doubles", "pearson") expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) }) test_that("freqItems() on a DataFrame", { diff --git a/docs/sparkr.md b/docs/sparkr.md index a744b76be7466..cfb9b41350f45 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,7 +286,7 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). @@ -351,3 +351,38 @@ summary(model) ##Sepal_Width 0.404655 {% endhighlight %}
    + +# R Function Name Conflicts + +When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a +function is masking another function. + +The following functions are masked by the SparkR package: + + + + + + + + + + + + + + + + + + + +
    Masked functionHow to Access
    cov in package:stats
    stats::cov(x, y = NULL, use = "everything",
    +           method = c("pearson", "kendall", "spearman"))
    filter in package:stats
    stats::filter(x, filter, method = c("convolution", "recursive"),
    +              sides = 2, circular = FALSE, init)
    sample in package:basebase::sample(x, size, replace = FALSE, prob = NULL)
    table in package:base
    base::table(...,
    +            exclude = if (useNA == "no") c(NA, NaN),
    +            useNA = c("no", "ifany", "always"),
    +            dnn = list.names(...), deparse.level = 1)
    + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + From 42f5dfa5a47c22b5d7ffe00c7bfd1962374c5a48 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 10:48:04 -0800 Subject: [PATCH 0824/1168] [SPARK-11849][SQL] Analyzer should replace current_date and current_timestamp with literals We currently rely on the optimizer's constant folding to replace current_timestamp and current_date. However, this can still result in different values for different instances of current_timestamp/current_date if the optimizer is not running fast enough. A better solution is to replace these functions in the analyzer in one shot. Author: Reynold Xin Closes #9833 from rxin/SPARK-11849. --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f00c451b5981a..84781cd57f3dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -65,9 +65,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -84,7 +83,8 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic), + PullOutNondeterministic, + ComputeCurrentTime), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1076,7 +1076,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case plan => plan transformExpressionsUp { + case p => p transformExpressionsUp { case udf @ ScalaUDF(func, _, inputs, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) @@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 08586a97411ae..e051069951887 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest { udf4) // checkUDF(udf4, expected4) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } From 8cc4b69bf192760ced7e0faa722b380ed8c5212c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Nov 2015 11:02:17 -0800 Subject: [PATCH 0825/1168] [SPARK-11840][SQL] Restore the 1.5's behavior of planning a single distinct aggregation. The impact of this change is for a query that has a single distinct column and does not have any grouping expression like `SELECT COUNT(DISTINCT a) FROM table` The plan will be changed from ``` AGG-2 (count distinct) Shuffle to a single reducer Partial-AGG-2 (count distinct) AGG-1 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on 1) ``` to the following one (1.5 uses this) ``` AGG-2 AGG-1 (grouping on a) Shuffle to a single reducer Partial-AGG-1(grouping on a) ``` The first plan is more robust. However, to better benchmark the impact of this change, we should use 1.5's plan and use the conf of `spark.sql.specializeSingleDistinctAggPlanning` to control the plan. Author: Yin Huai Closes #9828 from yhuai/distinctRewriter. --- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index c0c960471a61a..9c78f6d4cc71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,8 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { // When the flag is set to specialize single distinct agg planning, // we will rely on our Aggregation strategy to handle queries with a single - // distinct column and this aggregate operator does have grouping expressions. - distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + // distinct column. + distinctAggGroups.size > 1 } else { distinctAggGroups.size >= 1 } From 2d75201ce10f54a0dad8e719c6f92643245b8cca Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 19 Nov 2015 11:57:50 -0800 Subject: [PATCH 0826/1168] [SPARK-11830][CORE] Make NettyRpcEnv bind to the specified host This PR includes the following change: 1. Bind NettyRpcEnv to the specified host 2. Fix the port information in the log for NettyRpcEnv. 3. Fix the service name of NettyRpcEnv. Author: zsxwing Author: Shixiong Zhu Closes #9821 from zsxwing/SPARK-11830. --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 9 ++++++++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 7 +++---- .../org/apache/spark/network/TransportContext.java | 8 +++++++- .../spark/network/server/TransportServer.java | 14 ++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4474a83bedbdb..88df27f733f2a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -258,8 +258,15 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { + val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3e0c497969502..3ce359868039b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -102,7 +102,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -337,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -370,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 1b64b863a9fe5..238710d17249a 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -94,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8d..baae235e02205 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. + * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); From ce26610955a0081f17996d64f5f31810bee6b13b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:45:04 -0800 Subject: [PATCH 0827/1168] [SPARK-11633][SQL] LogicalRDD throws TreeNode Exception : Failed to Copy Node When handling self joins, the implementation did not consider the case insensitivity of HiveContext. It could cause an exception as shown in the JIRA: ``` TreeNodeException: Failed to copy node. ``` The fix is low risk. It avoids unnecessary attribute replacement. It should not affect the existing behavior of self joins. Also added the test case to cover this case. Author: gatorsmile Closes #9762 from gatorsmile/joinMakeCopy. --- .../apache/spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 62620ec642c78..623348f6768a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6399b0165c4c3..dd6d06512ff60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1110,6 +1110,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + // This test case is to verify a bug when making a new instance of LogicalRDD. + test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) From c05d8a9507813939373f74537e18ecc279e9dddd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:46:36 -0800 Subject: [PATCH 0828/1168] [SPARK-11848][SQL] Support EXPLAIN in DataSet APIs When debugging DataSet API, I always need to print the logical and physical plans. I am wondering if we should provide a simple API for EXPLAIN? Author: gatorsmile Closes #9832 from gatorsmile/explainDS. --- .../org/apache/spark/sql/DataFrame.scala | 23 +------------------ .../spark/sql/execution/Queryable.scala | 21 +++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ba4ba18d2122..98358127e2709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -308,27 +308,6 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) // scalastyle:on println - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09b..e86a52c149a2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -25,6 +26,7 @@ import scala.util.control.NonFatal private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution + def sqlContext: SQLContext override def toString: String = { try { @@ -34,4 +36,23 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(): Unit = explain(extended = false) } From 923dafda886beed368ec662ac6e28a4501701236 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2015 12:54:25 -0800 Subject: [PATCH 0829/1168] [SPARK-11750][SQL] revert SPARK-11727 and code clean up After some experiment, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when create encoders for `ScalaUDF`, we have no idea if the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests. Author: Wenchen Fan Closes #9726 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/Encoder.scala | 16 +- .../spark/sql/catalyst/ScalaReflection.scala | 354 +++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 19 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 -- .../catalyst/encoders/ProductEncoder.scala | 452 ------------------ .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/expressions/objects.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 68 --- .../encoders/ExpressionEncoderSuite.scala | 218 ++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 99 ---- .../encoders/ProductEncoderSuite.scala | 156 ------ .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 23 +- .../org/apache/spark/sql/functions.scala | 4 +- 14 files changed, 364 insertions(+), 1118 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index d54f2854fb33f..86bb536459035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable { */ object Encoders { - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 59ccf356f2c48..33ae700706dae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - ObjectType(clazz) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } } /** @@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -108,6 +98,15 @@ object ScalaReflection extends ScalaReflection { ObjectType(cls) } + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + /** * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes @@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection { * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = - path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. - case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + WrapOption(constructorFor(optType, path)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } val keyData = Invoke( MapObjects( p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) @@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) @@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - } } /** Returns expressions for extracting all the fields from the given type. */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + extractorFor(inputObject, localTypeOf[T]) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def extractorFor( inputObject: Expression, tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { @@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection { DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal, @@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) - case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } } } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } } /** @@ -635,8 +545,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. @@ -736,39 +645,4 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } - - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. - } - - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { - - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. - */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 456b595008479..6eeba1442c1f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] new ExpressionEncoder[T]( - extractExpression.dataType, + toRowExpression.dataType, flat, - extractExpression.flatten, - constructExpression, + toRowExpression.flatten, + fromRowExpression, ClassTag[T](cls)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala deleted file mode 100644 index 6d307ab13a9fc..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} -import org.apache.spark.sql.catalyst.ScalaReflection - -object FlatEncoder { - import ScalaReflection.schemaFor - import ScalaReflection.dataTypeFor - - def apply[T : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) - - val input = BoundReference(0, dataTypeFor(tpe), nullable = true) - val toRowExpression = CreateNamedStruct( - Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) - val fromRowExpression = ProductEncoder.constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = true, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 2914c6ee790ce..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import org.apache.spark.util.Utils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} - -import scala.reflect.ClassTag - -object ProductEncoder { - import ScalaReflection.universe._ - import ScalaReflection.mirror - import ScalaReflection.localTypeOf - import ScalaReflection.dataTypeFor - import ScalaReflection.Schema - import ScalaReflection.schemaFor - import ScalaReflection.arrayClassFor - - def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] - val fromRowExpression = constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = false, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. - case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - - val unwrapped = UnwrapOption(optionObjectType, inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) - } - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case other => - throw new UnsupportedOperationException(s"Encoder for type $other is not supported") - } - } - } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (RowEncoder.isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } - - def constructorFor( - tpe: `Type`, - path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) - - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) - - tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - WrapOption(null, constructorFor(optType, path)) - - case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) - - case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) - } - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keyData = - Invoke( - MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType - - // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) - } else { - constructorFor(fieldType, Some(addToPath(fieldName))) - } - } - - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) - - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { - newInstance - } - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9bb1602494b68..4cda4824acdc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,17 +133,8 @@ object RowEncoder { CreateStruct(convertedFields) } - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true - case _ => false - } - private def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if isNativeType(dt) => dt + case _ if ScalaReflection.isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f865a9408ef4e..ef7399e0196ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -300,10 +299,9 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. */ -case class WrapOption(optionType: DataType, child: Expression) +case class WrapOption(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression) throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 4ea410d492b01..c2aace1ef238e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cde0364f3dd9d..76459b34a484f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,24 +17,234 @@ package org.apache.spark.sql.catalyst.encoders +import java.sql.{Timestamp, Date} import java.util.Arrays import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.ArrayType -abstract class ExpressionEncoderSuite extends SparkFunSuite { - val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() +case class RepeatedStruct(s: Seq[PrimitiveData]) - protected def encodeDecodeTest[T]( +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. */ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, - encoder: ExpressionEncoder[T], testName: String): Unit = { test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema, outers).bind(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala deleted file mode 100644 index 07523d49f4266..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Encoders - -class FlatEncoderSuite extends ExpressionEncoderSuite { - encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") - encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") - encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") - encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") - encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") - encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") - encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") - - encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") - - encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") - type JDecimal = java.math.BigDecimal - // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") - - encodeDecodeTest("hello", FlatEncoder[String], "string") - encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") - encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") - encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") - - encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") - encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") - encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") - encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") - encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") - - encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), - FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") - encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), - FlatEncoder[Seq[Seq[String]]], "seq of seq of string") - - encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int") - encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") - encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") - encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") - encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") - - encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), - FlatEncoder[Array[Array[Int]]], "array of array of int") - encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), - FlatEncoder[Array[Array[String]]], "array of array of string") - - encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") - encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") - encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), - FlatEncoder[Map[Int, Map[String, Int]]], "map of map") - - // Kryo encoders - encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") - encodeDecodeTest(new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") - - // Java encoders - encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") - encodeDecodeTest(new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") -} - -/** For testing Kryo serialization based encoder. */ -class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value - } -} - -/** For testing Java serialization based encoder. */ -class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala deleted file mode 100644 index 1798514c5c38b..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) { - override def equals(other: Any): Boolean = other match { - case NestedArray(otherArray) => - java.util.Arrays.deepEquals( - a.asInstanceOf[Array[AnyRef]], - otherArray.asInstanceOf[Array[AnyRef]]) - case _ => false - } -} - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ProductEncoderSuite extends ExpressionEncoderSuite { - outers.put(getClass.getName, this) - - case class InnerClass(i: Int) - productTest(InnerClass(1)) - - productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - productTest( - OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - productTest(OptionalData(None, None, None, None, None, None, None, None)) - - productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - productTest(BoxedData(null, null, null, null, null, null, null)) - - productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - - productTest(("Seq[(String, String)]", - Seq(("a", "b")))) - productTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - productTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - productTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - productTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - productTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - productTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - productTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - productTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - productTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - productTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - productTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - productTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - productTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - productTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - productTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - productTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTest( - 1 -> 10L, - ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), - "tuple with 2 flat encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), - "tuple with 2 product encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), - "tuple with flat encoder and product encoder") - - encodeDecodeTest( - (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), - ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), - "tuple with product encoder and flat encoder") - - encodeDecodeTest( - (1, (10, 100L)), - { - val intEnc = FlatEncoder[Int] - val longEnc = FlatEncoder[Long] - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) - }, - "nested tuple encoder") - - private def productTest[T <: Product : TypeTag](input: T): Unit = { - encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7e5acbe8517d1..6de3dd626576a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) /** * Applies the given function to each cogrouped data. For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8471eea1b7d9c..25ffdcde17717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.StructField import org.apache.spark.unsafe.types.UTF8String @@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] - implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] - implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] - implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] - implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] - implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] - implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] - implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** * Creates a [[Dataset]] from an RDD. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 95158de710acf..b27b1340cce46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.FlatEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(FlatEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From ea67018b487b9c188e7cd5ab0344626adc8e0253 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 19 Nov 2015 13:08:01 -0800 Subject: [PATCH 0830/1168] [SPARK-11778][SQL] parse table name before it is passed to lookupRelation Fix a bug in DataFrameReader.table (table with schema name such as "db_name.table" doesn't work) Use SqlParser.parseTableIdentifier to parse the table name before lookupRelation. Author: Huaxin Gao Closes #9773 from huaxingao/spark-11778. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5872fbded3833..dcb3737b70fbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,7 +313,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf765265..f19a74d4b3724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,10 +34,14 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -74,4 +78,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } + + // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, + // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException + test("table name with schema") { + hiveContext.read.table("usrdb.test") + } } From cd79fb0d9c4502c630156a194db22d8676a6e7b2 Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Thu, 19 Nov 2015 13:57:23 -0800 Subject: [PATCH 0831/1168] [SPARK-11812][PYSPARK] invFunc=None works properly with python's reduceByKeyAndWindow invFunc is optional and can be None. Instead of invFunc (the parameter) invReduceFunc (a local function) was checked for trueness (that is, not None, in this context). A local function is never None, thus the case of invFunc=None (a common one when inverse reduction is not defined) was treated incorrectly, resulting in loss of data. In addition, the docstring used wrong parameter names, also fixed. Author: David Tolpin Closes #9775 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 6 +++--- python/pyspark/streaming/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18d..acec850f02c2d 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ def invReduceFunc(t, a, b): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f15532b5..3403f6d20d789 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): From 3e9d9ef69228c73ac0f34d88a5933acf6b8e537f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 14:48:18 -0800 Subject: [PATCH 0832/1168] [SPARK-11858][SQL] Move sql.columnar into sql.execution. In addition, tightened visibility of a lot of classes in the columnar package from private[sql] to private[columnar]. Author: Reynold Xin Closes #9842 from rxin/SPARK-11858. --- .../spark/sql/execution/CacheManager.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../columnar/ColumnAccessor.scala | 42 +++++++-------- .../columnar/ColumnBuilder.scala | 51 ++++++++++--------- .../columnar/ColumnStats.scala | 34 ++++++------- .../{ => execution}/columnar/ColumnType.scala | 48 ++++++++--------- .../columnar/GenerateColumnAccessor.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 5 +- .../columnar/NullableColumnAccessor.scala | 4 +- .../columnar/NullableColumnBuilder.scala | 4 +- .../CompressibleColumnAccessor.scala | 6 +-- .../CompressibleColumnBuilder.scala | 6 +-- .../compression/CompressionScheme.scala | 16 +++--- .../compression/compressionSchemes.scala | 16 +++--- .../apache/spark/sql/execution/package.scala | 2 + .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/QueryTest.scala | 2 +- .../columnar/ColumnStatsSuite.scala | 6 +-- .../columnar/ColumnTypeSuite.scala | 4 +- .../columnar/ColumnarTestUtils.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 6 +-- .../compression/DictionaryEncodingSuite.scala | 6 +-- .../compression/IntegralDeltaSuite.scala | 6 +-- .../compression/RunLengthEncodingSuite.scala | 6 +-- .../TestCompressibleColumnBuilder.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 30 files changed, 155 insertions(+), 147 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnAccessor.scala (75%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnBuilder.scala (74%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStats.scala (88%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnType.scala (93%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/GenerateColumnAccessor.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarTableScan.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessor.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilder.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnAccessor.scala (84%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnBuilder.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressionScheme.scala (83%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/compressionSchemes.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStatsSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnTypeSuite.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnarTestUtils.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessorSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilderSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/PartitionBatchPruningSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/BooleanBitSetSuite.scala (94%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/DictionaryEncodingSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/IntegralDeltaSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/RunLengthEncodingSuite.scala (95%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/TestCompressibleColumnBuilder.scala (93%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f85aeb1b02694..293fcfe96e677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3d4ce633c07c9..f67c951bc0663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.{Strategy, execution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 42ec4d3433f16..fee36f6023895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods * for primitive values provided by [[MutableRow]]. */ -private[sql] trait ColumnAccessor { +private[columnar] trait ColumnAccessor { initialize() protected def initialize() @@ -41,7 +41,7 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[JvmType]( +private[columnar] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { @@ -61,65 +61,65 @@ private[sql] abstract class BasicColumnAccessor[JvmType]( protected def underlyingBuffer = buffer } -private[sql] class NullColumnAccessor(buffer: ByteBuffer) +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Any](buffer, NULL) with NullableColumnAccessor -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor with CompressibleColumnAccessor[T] -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BYTE) -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, INT) -private[sql] class LongColumnAccessor(buffer: ByteBuffer) +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, DOUBLE) -private[sql] class StringColumnAccessor(buffer: ByteBuffer) +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) with NullableColumnAccessor -private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor -private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor -private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[sql] object ColumnAccessor { +private[columnar] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 599f30f2d73b4..7e26f19bb7449 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ -private[sql] trait ColumnBuilder { +private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. */ @@ -46,7 +46,7 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[JvmType]( +private[columnar] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, val columnType: ColumnType[JvmType]) extends ColumnBuilder { @@ -84,17 +84,17 @@ private[sql] class BasicColumnBuilder[JvmType]( } } -private[sql] class NullColumnBuilder +private[columnar] class NullColumnBuilder extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) with NullableColumnBuilder -private[sql] abstract class ComplexColumnBuilder[JvmType]( +private[columnar] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T#InternalType](columnStats, columnType) @@ -102,40 +102,45 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnBuilder(dataType: DecimalType) +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) -private[sql] class StructColumnBuilder(dataType: StructType) +private[columnar] class StructColumnBuilder(dataType: StructType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) -private[sql] class ArrayColumnBuilder(dataType: ArrayType) +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) -private[sql] class MapColumnBuilder(dataType: MapType) +private[columnar] class MapColumnBuilder(dataType: MapType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) -private[sql] object ColumnBuilder { +private[columnar] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 91a05650585cf..c52ee9ffd6d2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - private[sql] var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -72,14 +72,14 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * A no-op ColumnStats only used for testing purposes. */ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -97,7 +97,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -115,7 +115,7 @@ private[sql] class ByteColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -133,7 +133,7 @@ private[sql] class ShortColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -151,7 +151,7 @@ private[sql] class IntColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -169,7 +169,7 @@ private[sql] class LongColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -187,7 +187,7 @@ private[sql] class FloatColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -205,7 +205,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -223,7 +223,7 @@ private[sql] class StringColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -235,7 +235,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } -private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null @@ -256,7 +256,7 @@ private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends Column new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 68e509eb5047d..c9f2329db4b6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * * WARNNING: This only works with HeapByteBuffer */ -object ByteBufferHelper { +private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -73,7 +73,7 @@ object ByteBufferHelper { * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[columnar] sealed abstract class ColumnType[JvmType] { // The catalyst data type of this column. def dataType: DataType @@ -142,7 +142,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] object NULL extends ColumnType[Any] { +private[columnar] object NULL extends ColumnType[Any] { override def dataType: DataType = NullType override def defaultSize: Int = 0 @@ -152,7 +152,7 @@ private[sql] object NULL extends ColumnType[Any] { override def getField(row: InternalRow, ordinal: Int): Any = null } -private[sql] abstract class NativeColumnType[T <: AtomicType]( +private[columnar] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) extends ColumnType[T#InternalType] { @@ -163,7 +163,7 @@ private[sql] abstract class NativeColumnType[T <: AtomicType]( def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } -private[sql] object INT extends NativeColumnType(IntegerType, 4) { +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -192,7 +192,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } } -private[sql] object LONG extends NativeColumnType(LongType, 8) { +private[columnar] object LONG extends NativeColumnType(LongType, 8) { override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } @@ -220,7 +220,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } } -private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } @@ -248,7 +248,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } } -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } @@ -276,7 +276,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } } -private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def append(v: Boolean, buffer: ByteBuffer): Unit = { buffer.put(if (v) 1: Byte else 0: Byte) } @@ -302,7 +302,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } } -private[sql] object BYTE extends NativeColumnType(ByteType, 1) { +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } @@ -330,7 +330,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } } -private[sql] object SHORT extends NativeColumnType(ShortType, 2) { +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } @@ -362,7 +362,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper * objects. */ -private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { @@ -387,7 +387,7 @@ private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { } } -private[sql] object STRING +private[columnar] object STRING extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { @@ -425,7 +425,7 @@ private[sql] object STRING override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { @@ -467,13 +467,13 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) } } -private[sql] object COMPACT_DECIMAL { +private[columnar] object COMPACT_DECIMAL { def apply(dt: DecimalType): COMPACT_DECIMAL = { COMPACT_DECIMAL(dt.precision, dt.scale) } } -private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] @@ -492,7 +492,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } } -private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType @@ -512,7 +512,7 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayColumnType[Decimal](12) { override val dataType: DataType = DecimalType(precision, scale) @@ -539,13 +539,13 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) } } -private[sql] object LARGE_DECIMAL { +private[columnar] object LARGE_DECIMAL { def apply(dt: DecimalType): LARGE_DECIMAL = { LARGE_DECIMAL(dt.precision, dt.scale) } } -private[sql] case class STRUCT(dataType: StructType) +private[columnar] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -586,7 +586,7 @@ private[sql] case class STRUCT(dataType: StructType) override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) +private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -625,7 +625,7 @@ private[sql] case class ARRAY(dataType: ArrayType) override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) +private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -663,7 +663,7 @@ private[sql] case class MAP(dataType: MapType) override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } -private[sql] object ColumnType { +private[columnar] object ColumnType { def apply(dataType: DataType): ColumnType[_] = { dataType match { case NullType => NULL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff9393b465b7a..eaafc96e4d2e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -121,7 +121,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; - import org.apache.spark.sql.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ae77298e6da2f..ce701fb3a7f28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer @@ -50,7 +50,8 @@ private[sql] object InMemoryRelation { * @param buffers The buffers for serialized columns * @param stats The stat of columns */ -private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 7eaecfe047c3f..8d99546924de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index 76cfddf1cd01a..3a1931bfb5c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1ad..6579b5068e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 161021ff96154..b0e216feb5595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.AtomicType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 9322b772fd898..920381f9c63d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +37,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +53,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 41c9a284e3e4a..941f03b745a07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer @@ -23,11 +23,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -64,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -172,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -281,7 +281,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -371,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -451,7 +451,7 @@ private[sql] case object IntDelta extends CompressionScheme { } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 28fa231e722d0..c912734bba9e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -19,5 +19,7 @@ package org.apache.spark.sql /** * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. */ package object execution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index bce94dafad755..d86df4cfb9b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -280,7 +280,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sql("CACHE TABLE testData") sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5417b195f396..6ea1fe4ccfd89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89a664001bdd2..b2d04f7c5a6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -50,7 +50,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -86,7 +86,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 63bc39bfa0307..34dd96929e6c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index a5882f7870e37..9cae65ef6f5dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 6265e40a0a07b..25afed25c897b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.sql.{Date, Timestamp} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index aa1605fee8c73..35dc9a276cef7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer @@ -38,7 +38,7 @@ object TestNullableColumnAccessor { } class NullableColumnAccessorSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 91404577832a0..93be3e16a5ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -36,7 +36,7 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 6b7401464f46f..d762f7bfe914c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba42..ccbddef0fad3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d1..830ca0294e1b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62cb..988a577a7b4d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a43..ce3affba55c71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03e..5e078f251375a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 5c2fc7d82ffbd..99478e82d419f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} From cbc159754665eae89f1157dd09ebcffd5bfd024a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 19 Nov 2015 14:49:25 -0800 Subject: [PATCH 0833/1168] [SPARK-11831][CORE][TESTS] Use port 0 to avoid port conflicts in tests Use port 0 to fix port-contention-related flakiness Author: Shixiong Zhu Closes #9841 from zsxwing/SPARK-11831. --- .../org/apache/spark/rpc/RpcEnvSuite.scala | 24 +++++++++---------- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df866..2f55006420ce1 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -39,7 +39,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) } override def afterAll(): Unit = { @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,7 +497,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") @@ -543,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -571,8 +571,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { @volatile var message: String = null @@ -602,8 +602,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 6478ab51c4da2..7aac02775e1bf 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -40,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) + RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -59,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) + RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) From 037933af63995431ecd920faec52fabb344bc36f Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 19 Nov 2015 14:51:40 -0800 Subject: [PATCH 0834/1168] =?UTF-8?q?[SPARK-11799][CORE]=20Make=20it=20exp?= =?UTF-8?q?licit=20in=20executor=20logs=20that=20uncaught=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …xceptions are thrown during executor shutdown This commit will make sure that when uncaught exceptions are prepended with [Container in shutdown] when JVM is shutting down. Author: Srinivasa Reddy Vundela Closes #9809 from vundela/master_11799. --- .../apache/spark/util/SparkUncaughtExceptionHandler.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 7248187247330..5e322557e9649 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) From 298f1ec6099b641f1bd5f13236f75ba5dd093135 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 14:53:58 -0800 Subject: [PATCH 0835/1168] [SPARK-11828][CORE] Register DAGScheduler metrics source after app id is known. Author: Marcelo Vanzin Closes #9820 from vanzin/SPARK-11828. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ab374cb71286a..af4456c05b0a1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -581,6 +581,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(_dagScheduler.metricsSource) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a9518fff4e7b..ae725b467d8c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -130,7 +130,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -1580,8 +1580,6 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } From 9ada1d4aee823544cc74fab52d0526d3dab96400 Mon Sep 17 00:00:00 2001 From: hushan Date: Thu, 19 Nov 2015 14:56:00 -0800 Subject: [PATCH 0836/1168] [SPARK-11746][CORE] Use cache-aware method dependencies a small change Author: hushan Closes #9691 from suyanNone/unify-getDependency. --- .../main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index d6a37e8cc5dac..0c6ddda52cee9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } From da98810cb3b4e688882ea10baf0a97e1b2c0d92b Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 19 Nov 2015 15:11:30 -0800 Subject: [PATCH 0837/1168] [SPARK-11275][SQL] Incorrect results when using rollup/cube Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result. Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer. Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite: ``` build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite' ``` This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it. Author: Andrew Ray Closes #9815 from aray/groupingset-agg-fix. --- .../sql/catalyst/analysis/Analyzer.scala | 58 +++++++---------- .../plans/logical/basicOperators.scala | 4 ++ .../spark/sql/DataFrameAggregateSuite.scala | 62 +++++++++++++++++++ 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84781cd57f3dc..47962ebe6ef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -213,45 +213,35 @@ class Analyzer( GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - - // substitute the non-attribute expressions for aggregations. - val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val aggregations: Seq[NamedExpression] = x.aggregations.map { + // If an expression is an aggregate (contains a AggregateExpression) then we dont change + // it so that the aggregation is computed on the unmodified value of its argument + // expressions. + case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr + // If not then its a grouping expression and we need to use the modified (with nulls from + // Expand) value of the expression. + case expr => expr.transformDown { + case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + }.asInstanceOf[NamedExpression] } + val child = Project(x.child.output ++ groupByAliases, x.child) + val groupByAttributes = groupByAliases.map(_.toAttribute) + Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ VirtualColumn.groupingIdAttribute, + aggregations, + Expand(x.bitmasks, groupByAttributes, gid, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 45630a591d349..0c444482c5e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode { override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. + override lazy val resolved: Boolean = false + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 71adf2148a403..9c42f65bb6f52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("rollup overlapping columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), From 08cbbea441061f67e16f3ccd7eb70ec25a5bcd8e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 16:49:18 -0800 Subject: [PATCH 0838/1168] [SPARK-4134][CORE] Lower severity of some executor loss logs. Don't log ERROR messages when executors are explicitly killed or when the exit reason is not yet known. Author: Marcelo Vanzin Closes #9780 from vanzin/SPARK-11789. --- .../spark/scheduler/ExecutorLossReason.scala | 2 + .../spark/scheduler/TaskSchedulerImpl.scala | 44 ++++++++++++------- .../spark/scheduler/TaskSetManager.scala | 1 + .../CoarseGrainedSchedulerBackend.scala | 18 +++++--- .../spark/deploy/yarn/YarnAllocator.scala | 4 +- 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 47a5cbff4930b..7e1197d742802 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,8 @@ private[spark] object ExecutorExited { } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + /** * A loss reason that means we don't yet know why the executor exited. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bf0419db1f75e..bdf19f9f277d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -470,25 +470,25 @@ private[spark] class TaskSchedulerImpl( synchronized { if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - executorIdToHost.get(executorId) match { - case Some(_) => - // If the host mapping still exists, it means we don't know the loss reason for the - // executor. So call removeExecutor() to update tasks running on that executor when - // the real loss reason is finally known. - logError(s"Actual reason for lost executor $executorId: ${reason.message}") - removeExecutor(executorId, reason) - - case None => - // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a - // report of executor termination from Mesos. We produce log messages for both so we - // eventually report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. + logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -498,6 +498,18 @@ private[spark] class TaskSchedulerImpl( } } + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 114468c48c44c..a02f3017cb6e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -800,6 +800,7 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6f0c910c009a5..505c161141c88 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -64,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). + private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it protected var hostToLocalTaskCount: Map[String, Int] = Map.empty @@ -250,15 +252,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, reason) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") @@ -459,6 +461,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors @@ -479,7 +485,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsPendingToRemove ++= executorsToKill + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7e39c3ea56af3..73cd9031f0250 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -481,7 +481,7 @@ private[yarn] class YarnAllocator( (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) - case unknown => + case _ => numExecutorsFailed += 1 (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + @@ -493,7 +493,7 @@ private[yarn] class YarnAllocator( } else { logInfo(containerExitReason) } - ExecutorExited(0, exitCausedByApp, containerExitReason) + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed From 87a5ec3441462e62e9eb194bc535758752a627bb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 19 Nov 2015 16:50:08 -0800 Subject: [PATCH 0839/1168] [SPARK-11845][STREAMING][TEST] Added unit test to verify TrackStateRDD is correctly checkpointed To make sure that all lineage is correctly truncated for TrackStateRDD when checkpointed. Author: Tathagata Das Closes #9831 from tdas/SPARK-11845. --- .../org/apache/spark/CheckpointSuite.scala | 411 +++++++++--------- .../streaming/rdd/TrackStateRDDSuite.scala | 60 ++- 2 files changed, 267 insertions(+), 204 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e412..ab23326c6c25d 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,223 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.spark.CheckpointSuite._ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduce after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert( + collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of the all members of an RDD and it parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate an pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } - - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) - } - - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) - } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) - } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. - */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -494,5 +504,4 @@ object CheckpointSuite { part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 19ef5a14f8ab4..0feb3af1abb0f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -17,31 +17,40 @@ package org.apache.spark.streaming.rdd +import java.io.File + import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.util.OpenHashMapBasedStateMap -import org.apache.spark.streaming.{Time, State} -import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils -class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { +class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { private var sc: SparkContext = null + private var checkpointDir: File = _ override def beforeAll(): Unit = { sc = new SparkContext( new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) } override def afterAll(): Unit = { if (sc != null) { sc.stop() } + Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + test("creation from pair RDD") { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) @@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { rdd7, Seq(("k3", 2)), Set()) } + test("checkpointing") { + /** + * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs - + * the data RDD and the parent TrackStateRDD. + */ + def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) } + .collect.toSet + } + + /** Generate TrackStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : TrackStateRDD[Int, Int, Int, Int] = { + TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate TrackStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = { + + // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent + new TrackStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From d83d6e240e3b8fdb907eb3dd95cd9c2c4883add3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Nov 2015 17:14:10 -0800 Subject: [PATCH 0840/1168] [SPARK-11864][SQL] Improve performance of max/min This PR has the following optimization: 1) The greatest/least already does the null-check, so the `If` and `IsNull` are not necessary. 2) In greatest/least, it should initialize the result using the first child (removing one block). 3) For primitive types, the generated greater expression is too complicated (`a > b ? 1 : (a < b) ? -1 : 0) > 0`), should be as simple as `a > b` Combine these optimization, this could improve the performance of `ss_max` query by 30%. Author: Davies Liu Closes #9846 from davies/improve_max. --- .../catalyst/expressions/aggregate/Max.scala | 5 +-- .../catalyst/expressions/aggregate/Min.scala | 5 +-- .../expressions/codegen/CodeGenerator.scala | 12 ++++++ .../expressions/conditionalExpressions.scala | 38 +++++++++++-------- .../expressions/nullExpressions.scala | 10 +++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 61cae44cd0f5b..906003188d4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -46,13 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + /* max = */ Greatest(Seq(max, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val greatest = Greatest(Seq(max.left, max.right)) Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + /* max = */ Greatest(Seq(max.left, max.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 242456d9e2e18..39f7afbd081cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -47,13 +47,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + /* min = */ Least(Seq(min, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val least = Least(Seq(min.left, min.right)) Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + /* min = */ Least(Seq(min.left, min.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd35332..1b7260cdfe515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,6 +329,18 @@ class CodeGenContext { throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 0d4af43978ea1..694a2a7c54a90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -403,19 +406,22 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 94deafb75b69c..df4747d4e6f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -62,11 +62,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { From 5f78f632d702ad5ed0841c03a8bcb8423be490e7 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Nov 2015 19:46:10 -0800 Subject: [PATCH 0841/1168] [SPARK-11544][SQL][TEST-HADOOP1.0] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9830 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++--- .../datasources/json/JsonSuite.scala | 41 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df63..f9465157c936d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178affe..ba7718c864637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 4176e345e674244ea39366d1b12fa3177a442999 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 22:01:02 -0800 Subject: [PATCH 0842/1168] [SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression https://issues.apache.org/jira/browse/SPARK-11846 mengxr Author: Xusen Yin Closes #9836 from yinxusen/SPARK-11846. --- .../ml/regression/AFTSurvivalRegression.scala | 78 +++++++++++++++-- .../ml/regression/IsotonicRegression.scala | 83 +++++++++++++++++-- .../AFTSurvivalRegressionSuite.scala | 37 ++++++++- .../regression/IsotonicRegressionSuite.scala | 34 +++++++- 4 files changed, 210 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index b7d095872ffa5..aedfb48058dc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -21,20 +21,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. @@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params @Experimental @Since("1.6.0") class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { @Since("1.6.0") def this() = this(Identifiable.randomUID("aftSurvReg")) @@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) } +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[AFTSurvivalRegression]]. @@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) - extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { /** @group setParam */ @Since("1.6.0") @@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] ( copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[AFTSurvivalRegressionModel].getName + + override def load(path: String): AFTSurvivalRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("coefficients", "intercept", "scale").head() + val coefficients = data.getAs[Vector](0) + val intercept = data.getDouble(1) + val scale = data.getDouble(2) + val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a1fe01b047108..bbb1c7ac0a51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures @Since("1.5.0") @Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) @@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) - extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { /** @group setParam */ @Since("1.5.0") @@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172b..d718ef63b531a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -21,14 +21,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(prediction ~== model.predict(features) relTol 1E-5) } } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 59f4193abc8f0..f067c29d27a7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { sqlContext.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } @@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } From 2c04d36a5b9db96aac5fb10bc702023f7f5893cf Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:02:17 -0800 Subject: [PATCH 0843/1168] [SPARK-11829][ML] Add read/write to estimators under ml.feature (II) Add read/write support to the following estimators under spark.ml: * ChiSqSelector * PCA * VectorIndexer * Word2Vec Author: Yanbo Liang Closes #9838 from yanboliang/spark-11829. --- .../spark/ml/feature/ChiSqSelector.scala | 65 ++++++++++++++++-- .../org/apache/spark/ml/feature/PCA.scala | 67 +++++++++++++++++-- .../spark/ml/feature/VectorIndexer.scala | 66 ++++++++++++++++-- .../apache/spark/ml/feature/Word2Vec.scala | 67 +++++++++++++++++-- .../apache/spark/mllib/feature/Word2Vec.scala | 6 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 22 +++++- .../apache/spark/ml/feature/PCASuite.scala | 26 ++++++- .../spark/ml/feature/VectorIndexerSuite.scala | 22 +++++- .../spark/ml/feature/Word2VecSuite.scala | 30 ++++++++- 9 files changed, 338 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 5e4061fba5494..dfec03828f4b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint @@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params */ @Experimental final class ChiSqSelector(override val uid: String) - extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("chiSqSelector")) @@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String) override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[ChiSqSelector]]. @@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String) final class ChiSqSelectorModel private[ml] ( override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) - extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] ( val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 539084704b653..32d7afee6e73b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. */ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. @@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams class PCAModel private[ml] ( override val uid: String, pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ + + /** a principal components Matrix. Each column is one principal component. */ + val pc: DenseMatrix = pcaModel.pc /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,4 +142,46 @@ class PCAModel private[ml] ( val copied = new PCAModel(uid, pcaModel) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(k: Int, pc: DenseMatrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.getK, instance.pc) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("k", "pc") + .head() + val oldModel = new feature.PCAModel(k, pc) + val model = new PCAModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 52e0599e38d83..a637a6f2881de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,12 +22,14 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.udf @@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -146,7 +152,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. */ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] ( val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 708dbeef84db4..a8d61b6dea00b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. @@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] class Word2VecModel private[ml] ( override val uid: String, @transient private val wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and @@ -224,4 +236,49 @@ class Word2VecModel private[ml] ( val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7ab0d89d23a3f..a47f27b0afb14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Since("1.1.0") -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e5a42967bd2c8..7827db2794cf3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + test("Test Chi-Square selector") { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vec1 ~== vec2 absTol 1e-1) } } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 30c500f87a769..5a21cd20ceede 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) @@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("read/write") { + + def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { + assert(model1.pc === model2.pc) + } + val allParams: Map[String, Any] = Map( + "k" -> 3, + "inputCol" -> "features", + "outputCol" -> "pca_features" + ) + val data = Seq( + (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), + (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + ) + val df = sqlContext.createDataFrame(data).toDF("id", "features") + val pca = new PCA().setK(3) + testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8cb0a2cf14d37..67817fa4baf56 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 23dfdaa9f8fc6..a773244cd735e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + } + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + } } From 2a1f897f8948a3c8cb5dacd4fbf59bc724dec54a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:14:01 -0800 Subject: [PATCH 0844/1168] [SPARK-11875][ML][PYSPARK] Update doc for PySpark HasCheckpointInterval * Update doc for PySpark ```HasCheckpointInterval``` that users can understand how to disable checkpoint. * Update doc for PySpark ```cacheNodeIds``` of ```DecisionTreeParams``` to notify the relationship between ```cacheNodeIds``` and ```checkpointInterval```. Author: Yanbo Liang Closes #9856 from yanboliang/spark-11875. --- python/pyspark/ml/param/_shared_params_code_gen.py | 6 ++++-- python/pyspark/ml/param/shared.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 070c5db01ae73..0528dc1e3a6b9 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -118,7 +118,8 @@ def get$Name(self): ("inputCols", "input column names.", None), ("outputCol", "output column name.", "self.uid + '__output'"), ("numFeatures", "number of features.", None), - ("checkpointInterval", "checkpoint interval (>= 1).", None), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), ("seed", "random seed.", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms.", None), ("stepSize", "Step size to be used for each iteration of optimization.", None), @@ -157,7 +158,8 @@ def get$Name(self): ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.")] decisionTreeCode = '''class DecisionTreeParams(Params): """ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4bdf2a8cc563f..4d960801502c2 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -325,16 +325,16 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1).") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1). - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1).") + #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def setCheckpointInterval(self, value): """ @@ -636,7 +636,7 @@ class DecisionTreeParams(Params): minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def __init__(self): @@ -651,8 +651,8 @@ def __init__(self): self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") #: param for Maximum memory in MB allocated to histogram aggregation. self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def setMaxDepth(self, value): """ From 798c825030ea43ac8a445911c1f7484b74179b23 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 19 Nov 2015 23:42:24 -0800 Subject: [PATCH 0845/1168] [SPARK-11869][ML] Clean up TempDirectory properly in ML tests Need to remove parent directory (```className```) rather than just tempDir (```className/random_name```) I tested this with IDFSuite, which has 2 read/write tests, and it fixes the problem. CC: mengxr Can you confirm this is fine? I believe it is since the same ```random_name``` is used for all tests in a suite; we basically have an extra unneeded level of nesting. Author: Joseph K. Bradley Closes #9851 from jkbradley/tempdir-cleanup. --- .../src/test/scala/org/apache/spark/ml/util/TempDirectory.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 2742026a69c2e..c8a0bb16247b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -35,7 +35,7 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => override def beforeAll(): Unit = { super.beforeAll() - _tempDir = Utils.createTempDir(this.getClass.getName) + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) } override def afterAll(): Unit = { From fc3109d65e236ea41c0decfa15bbcafe401edbaf Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 23:43:18 -0800 Subject: [PATCH 0846/1168] [SPARK-11867] Add save/load for kmeans and naive bayes https://issues.apache.org/jira/browse/SPARK-11867 Author: Xusen Yin Closes #9849 from yinxusen/SPARK-11867. --- .../spark/ml/classification/NaiveBayes.scala | 68 +++++++++++++++++-- .../apache/spark/ml/clustering/KMeans.scala | 67 ++++++++++++++++-- .../ml/classification/NaiveBayesSuite.scala | 47 +++++++++++-- .../spark/ml/clustering/KMeansSuite.scala | 41 ++++++++--- 4 files changed, 195 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf5b9..c512a2cb8bf3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,12 +17,15 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { @Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("nb")) @@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String) override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be63002396..71e968497500f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel */ @@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +131,52 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +189,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 98bc9511163e7..082a6bcd211ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d161..2724e51f31aa4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } From f033d48af20c1d9e0cc6b5069a5857a3f4191c96 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 00:46:29 -0800 Subject: [PATCH 0847/1168] [SPARK-11877] Prevent agg. fallback conf. from leaking across test suites This patch fixes an issue where the `spark.sql.TungstenAggregate.testFallbackStartsAt` SQLConf setting was not properly reset / cleared at the end of `TungstenAggregationQueryWithControlledFallbackSuite`. This ended up causing test failures in HiveCompatibilitySuite in Maven builds by causing spilling to occur way too frequently. This configuration leak was inadvertently introduced during test cleanup in #9618. Author: Josh Rosen Closes #9857 from JoshRosen/clear-fallback-prop-in-test-teardown. --- .../execution/AggregationQuerySuite.scala | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6dde79f74d3d8..39c0a2a0de045 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -868,29 +868,27 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( - "spark.sql.TungstenAggregate.testFallbackStartsAt", - fallbackStartsAt.toString) - - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = DataFrame(sqlContext, actual.logicalPlan) - - QueryTest.checkAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using TungstenAggregate with - |controlled fallback (it falls back to sort-based aggregation once it has processed - |$fallbackStartsAt input rows). The query is - |${actual.queryExecution} - | - |$errorMessage - """.stripMargin - - fail(newErrorMessage) - case None => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } } } } From f88ae67b1cbd38a2fc5d92838b5ad9a0222e6c6c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 20 Nov 2015 09:55:53 -0800 Subject: [PATCH 0848/1168] [SPARK-11852][ML] StandardScaler minor refactor ```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```. Author: Yanbo Liang Closes #9839 from yanboliang/standardScaler-refactor. --- .../spark/ml/feature/StandardScaler.scala | 60 +++++++++---------- .../ml/feature/StandardScalerSuite.scala | 11 ++-- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6d545219ebf49..d76a9c6275e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. * Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] { /** * :: Experimental :: * Model fitted by [[StandardScaler]]. + * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) + val std: Vector, + val mean: Vector) extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std - - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean - - /** Whether to scale to unit standard deviation. */ - @Since("1.6.0") - def getWithStd: Boolean = scaler.withStd - - /** Whether to center data with mean. */ - @Since("1.6.0") - def getWithMean: Boolean = scaler.withMean - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -139,6 +137,7 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -154,7 +153,7 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) + val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) } @@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { - private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + private case class Data(std: Vector, mean: Vector) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = - sqlContext.read.parquet(dataPath) - .select("std", "mean", "withStd", "withMean") - .head() - // This is very likely to change in the future because withStd and withMean should be params. - val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) - val model = new StandardScalerModel(metadata.uid, oldModel) + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 49a4b2efe0c29..1eae125a524ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("params") { ParamsSuite.checkParams(new StandardScaler) - val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) - ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) } test("Standardization with default parameter") { @@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("StandardScalerModel read/write") { - val oldModel = new feature.StandardScalerModel( - Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) - val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val instance = new StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) val newInstance = testDefaultReadWrite(instance) assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) - assert(newInstance.getWithStd === instance.getWithStd) - assert(newInstance.getWithMean === instance.getWithMean) } } From 166a1b202403ebe0a28685fd003ab6027b4e01a7 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 20 Nov 2015 09:57:09 -0800 Subject: [PATCH 0849/1168] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Author: Yuhao Yang Closes #9722 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 30 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 94 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 0000000000000..1743ef43a6ddf --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,30 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +
    +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a17..6f35b30c3d4df 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %}
    -
    +
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec4..54e35fcbb15af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 0000000000000..b3a7d2eb29780 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 0000000000000..419ce3d87a6ac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. + * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 9b35667e92076244151d7398db3b135e054e71e1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 11:20:47 -0800 Subject: [PATCH 0850/1168] [SPARK-11876][SQL] Support printSchema in DataSet API DataSet APIs look great! However, I am lost when doing multiple level joins. For example, ``` val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2").printSchema() ``` The printed schema is like ``` root |-- _1: struct (nullable = true) | |-- _1: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) | |-- _2: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) |-- _2: struct (nullable = true) | |-- _1: string (nullable = true) | |-- _2: integer (nullable = true) ``` Personally, I think we need the printSchema function. Sometimes, I do not know how to specify the column, especially when their data types are mixed. For example, if I want to write the following select for the above multi-level join, I have to know the schema: ``` newDS.select(expr("_1._2._2 + 1").as[Int]).collect() ``` marmbrus rxin cloud-fan Do you have the same feeling? Author: gatorsmile Closes #9855 from gatorsmile/printSchemaDataSet. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 9 --------- .../scala/org/apache/spark/sql/execution/Queryable.scala | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 98358127e2709..7abcecaa2880e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -299,15 +299,6 @@ class DataFrame private[sql]( */ def columns: Array[String] = schema.fields.map(_.name) - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index e86a52c149a2f..321e2c783537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -37,6 +37,15 @@ private[sql] trait Queryable { } } + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + /** * Prints the plans (logical and physical) to the console for debugging purposes. * @since 1.3.0 From e81e26f0d8038ba7d860a1719a4d19e08b5c7d54 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Nov 2015 11:43:45 -0800 Subject: [PATCH 0851/1168] [SPARK-11817][SQL] Truncating the fractional seconds to prevent inserting a NULL JIRA: https://issues.apache.org/jira/browse/SPARK-11817 Instead of return None, we should truncate the fractional seconds to prevent inserting NULL. Author: Liang-Chi Hsieh Closes #9834 from viirya/truncate-fractional-sec. --- .../apache/spark/sql/catalyst/util/DateTimeUtils.scala | 5 +++++ .../spark/sql/catalyst/util/DateTimeUtilsSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 17a5527f3fb29..2b93882919487 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -327,6 +327,11 @@ object DateTimeUtils { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index faca128badfd6..0ce5a2fb69505 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -343,6 +343,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { From 7a45b10474259267319bb01124a70a8eb24622b0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 12:04:42 -0800 Subject: [PATCH 0852/1168] [SPARK-11819][SQL] nice error message for missing encoder before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is not supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan Closes #9810 from cloud-fan/error-message. --- .../spark/sql/catalyst/ScalaReflection.scala | 90 ++++++++++++++----- .../encoders/EncoderErrorMessageSuite.scala | 62 +++++++++++++ 2 files changed, 129 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 33ae700706dae..918050b531c02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) className match { case "scala.Array" => val TypeRef(_, _, Seq(elementType)) = tpe @@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection { } } - /** Returns expressions for extracting all the fields from the given type. */ + /** + * Returns expressions for extracting all the fields from the given type. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - extractorFor(inputObject, localTypeOf[T]) match { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + extractorFor(inputObject, tpe, walkedTypePath) match { case s: CreateNamedStruct => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } @@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection { /** Helper for extracting internal fields from a case class. */ private def extractorFor( inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here + // to trigger the type check. + extractorFor(inputObject, elementType, newPath) + + MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + } + } + if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(optType) val classObj = Utils.classForName(className) val optionObjectType = ObjectType(classObj) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(optionObjectType, inputObject) expressions.If( IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + extractorFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) case t if t <:< localTypeOf[Array[_]] => @@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection { Invoke(inputObject, "booleanValue", BooleanType) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } } /** @@ -561,7 +588,7 @@ trait ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -637,6 +664,23 @@ trait ScalaReflection { } } + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. + */ + private def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) + } + + /** Returns the full class name for a type. */ + private def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } + /** * Returns classes of input parameters of scala function object. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index 0b2a10bb04c10..8c766ef829923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -17,9 +17,22 @@ package org.apache.spark.sql.catalyst.encoders +import scala.reflect.ClassTag + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) class EncoderErrorMessageSuite extends SparkFunSuite { @@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite { intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } From 44adab79a5deb1163610a2c9889e40bbe50d5186 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 13:17:35 -0800 Subject: [PATCH 0853/1168] [SPARK-11650] Reduce RPC timeouts to speed up slow AkkaUtilsSuite test This patch reduces some RPC timeouts in order to speed up the slow "AkkaUtilsSuite.remote fetch ssl on - untrusted server", which used to take two minutes to run. Author: Josh Rosen Closes #9869 from JoshRosen/SPARK-11650. --- core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 61601016e005e..0af4b6098bb0a 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + .set("spark.rpc.askTimeout", "5s") + .set("spark.rpc.lookupTimeout", "5s") val securityManagerBad = new SecurityManager(slaveConf) val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) try { slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) fail("should receive either ActorNotFound or TimeoutException") From cf2d6fede8d1e888717a4db317f150c7956b3562 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 14:19:34 -0800 Subject: [PATCH 0854/1168] [SPARK-11724][SQL] Change casting between int and timestamp to consistently treat int in seconds. Hive has since changed this behavior as well. https://issues.apache.org/jira/browse/HIVE-3454 Author: Nong Li Author: Nong Li Author: Yin Huai Closes #9685 from nongli/spark-11724. --- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++-- .../sql/catalyst/expressions/CastSuite.scala | 16 +++++---- .../apache/spark/sql/DateFunctionsSuite.scala | 3 ++ ...esting-0-237a6af90a857da1efcbe98f6bbbf9d6} | 2 +- ... cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 - ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 - ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 34 +++++++++++++------ 8 files changed, 39 insertions(+), 25 deletions(-) rename sql/hive/src/test/resources/golden/{constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 => constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6} (52%) delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5564e242b0472..533d17ea5c172 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -204,8 +204,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -647,7 +647,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f8..ab77a764483e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 241cbd0115070..a61c3aa48a73f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -448,6 +448,9 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) } test("to_unix_timestamp") { diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 similarity index 52% rename from sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 rename to sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 index 7c41615f8c184..a01c2622c68e2 100644 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -1 +1 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 deleted file mode 100644 index d00491fd7e5bb..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970b..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51d..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f0a7a6cc7a1e3..8a5acaf3e10bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} import scala.util.Try @@ -248,12 +249,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -603,26 +609,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From b7633c880255e6cd4a1c3b74ef77d66770c445e8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 20 Nov 2015 14:23:01 -0800 Subject: [PATCH 0855/1168] [SPARK-11870][STREAMING][PYSPARK] Rethrow the exceptions in TransformFunction and TransformFunctionSerializer TransformFunction and TransformFunctionSerializer don't rethrow the exception, so when any exception happens, it just return None. This will cause some weird NPE and confuse people. Author: Shixiong Zhu Closes #9847 from zsxwing/pyspark-streaming-exception. --- python/pyspark/streaming/tests.py | 16 ++++++++++++++++ python/pyspark/streaming/util.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 3403f6d20d789..a0e0267cafa58 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,22 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_failed_func(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("failed") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + return + + self.fail("a failed func should throw an error") + class StreamingListenerTests(PySparkStreamingTestCase): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283bd..767c732eb90b4 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -64,6 +64,7 @@ def call(self, milliseconds, jrdds): return r._jrdd except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunction(%s)" % self.func @@ -95,6 +96,7 @@ def dumps(self, id): return bytearray(self.serializer.dumps((func.func, func.deserializers))) except Exception: traceback.print_exc() + raise def loads(self, data): try: @@ -102,6 +104,7 @@ def loads(self, data): return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer From e7f4c6fc80212ee19fff861763bdb165480f1cd5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 14:31:26 -0800 Subject: [PATCH 0856/1168] [SPARK-11887] Close PersistenceEngine at the end of PersistenceEngineSuite tests In PersistenceEngineSuite, we do not call `close()` on the PersistenceEngine at the end of the test. For the ZooKeeperPersistenceEngine, this causes us to leak a ZooKeeper client, causing the logs of unrelated tests to be periodically spammed with connection error messages from that client: ``` 15/11/20 05:13:35.789 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) INFO ClientCnxn: Opening socket connection to server localhost/127.0.0.1:15741. Will not attempt to authenticate using SASL (unknown error) 15/11/20 05:13:35.790 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) WARN ClientCnxn: Session 0x15124ff48dd0000 for server null, unexpected error, closing socket connection and attempting reconnect java.net.ConnectException: Connection refused at sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) at sun.nio.ch.SocketChannelImpl.finishConnect(SocketChannelImpl.java:739) at org.apache.zookeeper.ClientCnxnSocketNIO.doTransport(ClientCnxnSocketNIO.java:350) at org.apache.zookeeper.ClientCnxn$SendThread.run(ClientCnxn.java:1068) ``` This patch fixes this by using a `finally` block. Author: Josh Rosen Closes #9864 from JoshRosen/close-zookeeper-client-in-tests. --- .../master/PersistenceEngineSuite.scala | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 34775577de8a3..7a44728675680 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -63,56 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = testRpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = - persistenceEngine.readPersistedData(testRpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - testRpcEnv.shutdown() - testRpcEnv.awaitTermination() + persistenceEngine.close() } } From 413a73be95c67dcfa3d0d5567d5aa3c4765daffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Fri, 20 Nov 2015 14:45:40 -0800 Subject: [PATCH 0857/1168] [SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating the UserDefinedFunction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-11716 This is one is #9739 and a regression test. When commit it, please make sure the author is jbonofre. You can find the original PR at https://github.com/apache/spark/pull/9739 closes #9739 Author: Jean-Baptiste Onofré Author: Yin Huai Closes #9868 from yhuai/SPARK-11716. --- .../apache/spark/sql/UDFRegistration.scala | 48 +++++++++---------- .../scala/org/apache/spark/sql/UDFSuite.scala | 15 ++++++ 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c533a..051694c0d43a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try($inputTypes).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -315,7 +315,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 9837fa6bdb357..fd736718af12c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext { | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp """.stripMargin).toDF(), complexData.select("m", "a", "b")) } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. + checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) + } } From d9ba3c7bd43e902b5e839bb7915081ca52accdb2 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 20 Nov 2015 15:10:55 -0800 Subject: [PATCH 0858/1168] [SPARK-11756][SPARKR] Fix use of aliases - SparkR can not output help information for SparkR:::summary correctly Fix use of aliases and changes uses of rdname and seealso `aliases` is the hint for `?` - it should not be linked to some other name - those should be seealso https://cran.r-project.org/web/packages/roxygen2/vignettes/rd.html Clean up usage on family, as multiple use of family with the same rdname is causing duplicated See Also html blocks (like http://spark.apache.org/docs/latest/api/R/count.html) Also changing some rdname for dplyr-like variant for better R user visibility in R doc, eg. rbind, summary, mutate, summarize shivaram yanboliang Author: felixcheung Closes #9750 from felixcheung/rdocaliases. --- R/pkg/R/DataFrame.R | 96 ++++++++++++--------------------------------- R/pkg/R/broadcast.R | 1 - R/pkg/R/generics.R | 12 +++--- R/pkg/R/group.R | 12 +++--- 4 files changed, 37 insertions(+), 84 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06b0108b1389e..8a13e7a36766d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,7 +254,6 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns -#' @aliases names #' @export #' @examples #'\dontrun{ @@ -272,7 +271,6 @@ setMethod("columns", }) }) -#' @family DataFrame functions #' @rdname columns #' @name names setMethod("names", @@ -281,7 +279,6 @@ setMethod("names", columns(x) }) -#' @family DataFrame functions #' @rdname columns #' @name names<- setMethod("names<-", @@ -533,14 +530,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @family DataFrame functions -#' @rdname unique +#' @rdname distinct #' @name unique -#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -557,7 +548,7 @@ setMethod("unique", #' #' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ @@ -579,7 +570,6 @@ setMethod("sample", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -589,16 +579,15 @@ setMethod("sample_frac", sample(x, withReplacement, fraction) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' #' @family DataFrame functions -#' @rdname count +#' @rdname nrow #' @name count -#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -614,14 +603,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' -#' @family DataFrame functions #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -870,7 +853,6 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by #' @family DataFrame functions #' @rdname groupBy #' @name groupBy @@ -896,7 +878,6 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family DataFrame functions #' @rdname groupBy #' @name group_by setMethod("group_by", @@ -913,7 +894,6 @@ setMethod("group_by", #' @family DataFrame functions #' @rdname agg #' @name agg -#' @aliases summarize #' @export setMethod("agg", signature(x = "DataFrame"), @@ -921,7 +901,6 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family DataFrame functions #' @rdname agg #' @name summarize setMethod("summarize", @@ -1092,7 +1071,6 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @family DataFrame functions #' @rdname subset #' @name subset -#' @aliases [ #' @family subsetting functions #' @examples #' \dontrun{ @@ -1216,7 +1194,7 @@ setMethod("selectExpr", #' @family DataFrame functions #' @rdname withColumn #' @name withColumn -#' @aliases mutate transform +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1231,7 +1209,6 @@ setMethod("withColumn", function(x, colName, col) { select(x, x$"*", alias(col, colName)) }) - #' Mutate #' #' Return a new DataFrame with the specified columns added. @@ -1240,9 +1217,9 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name mutate -#' @aliases withColumn transform +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ @@ -1273,17 +1250,15 @@ setMethod("mutate", }) #' @export -#' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name transform -#' @aliases withColumn mutate setMethod("transform", signature(`_data` = "DataFrame"), function(`_data`, ...) { mutate(`_data`, ...) }) -#' WithColumnRenamed +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1292,8 +1267,9 @@ setMethod("transform", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @family DataFrame functions -#' @rdname withColumnRenamed +#' @rdname rename #' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1316,17 +1292,9 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. -#' @family DataFrame functions -#' @rdname withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename #' @name rename -#' @aliases withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1371,7 +1339,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @family DataFrame functions #' @rdname arrange #' @name arrange -#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1395,8 +1362,8 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname arrange +#' @name arrange #' @export setMethod("arrange", signature(x = "DataFrame", col = "character"), @@ -1427,9 +1394,9 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family DataFrame functions #' @rdname arrange -#' @name orderby +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1492,6 +1459,7 @@ setMethod("where", #' @family DataFrame functions #' @rdname join #' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ @@ -1528,9 +1496,7 @@ setMethod("join", dataFrame(sdf) }) -#' #' @name merge -#' @aliases join #' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined @@ -1550,6 +1516,7 @@ setMethod("join", #' outer join will be returned. #' @family DataFrame functions #' @rdname merge +#' @seealso \link{join} #' @export #' @examples #'\dontrun{ @@ -1671,7 +1638,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' UnionAll +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1681,7 +1648,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. #' @family DataFrame functions -#' @rdname unionAll +#' @rdname rbind #' @name unionAll #' @export #' @examples @@ -1700,13 +1667,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -#' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family DataFrame functions #' @rdname rbind #' @name rbind -#' @aliases unionAll +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1795,7 +1760,6 @@ setMethod("except", #' @family DataFrame functions #' @rdname write.df #' @name write.df -#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1828,7 +1792,6 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) -#' @family DataFrame functions #' @rdname write.df #' @name saveDF #' @export @@ -1891,7 +1854,7 @@ setMethod("saveAsTable", callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. @@ -1901,9 +1864,8 @@ setMethod("saveAsTable", #' @param ... Additional expressions #' @return A DataFrame #' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe -#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1923,8 +1885,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe setMethod("describe", signature(x = "DataFrame"), @@ -1934,11 +1895,6 @@ setMethod("describe", dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' -#' @family DataFrame functions #' @rdname summary #' @name summary setMethod("summary", @@ -1966,7 +1922,6 @@ setMethod("summary", #' @family DataFrame functions #' @rdname nafunctions #' @name dropna -#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1993,7 +1948,6 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname nafunctions #' @name na.omit #' @export @@ -2019,9 +1973,7 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. -#' @return A DataFrame #' -#' @family DataFrame functions #' @rdname nafunctions #' @name fillna #' @export diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c8..38f0eed95e065 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 71004a05ba611..1b3f10ea04643 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -397,7 +397,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @export setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) -#' @rdname describe +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -459,11 +459,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) @@ -475,7 +475,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) @@ -553,7 +553,7 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -565,7 +565,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index e5f702faee65d..23b49aebda05f 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. @@ -78,12 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg #' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -110,8 +112,8 @@ setMethod("agg", dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { From b7c6d1b960958edf9d46436426c575bca70b07da Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:17:17 -0800 Subject: [PATCH 0859/1168] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders #theScaryParts (i.e. changes to the repl, executor classloaders and codegen)... Author: Michael Armbrust Author: Yin Huai Closes #9825 from marmbrus/dataset-replClasses2. --- .../org/apache/spark/repl/SparkIMain.scala | 14 +++++++---- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++++ .../spark/repl/ExecutorClassLoader.scala | 8 ++++++- .../expressions/codegen/CodeGenerator.scala | 4 ++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11e..829b12269fd2b 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. + | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669bee..081aa03002cc6 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3d2d235a00c93..a976e96809cb8 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader case e: ClassNotFoundException => { val classOption = findClassLocally(name) classOption match { - case None => throw new ClassNotFoundException(name, e) + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) case Some(a) => a } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1b7260cdfe515..2f3d6aeb86c5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From 26479371f8b51c255cc4b850ab22fccfbd4cd373 Mon Sep 17 00:00:00 2001 From: Vikas Nelamangala Date: Fri, 20 Nov 2015 15:18:41 -0800 Subject: [PATCH 0860/1168] [SPARK-11549][DOCS] Replace example code in mllib-evaluation-metrics.md using include_example Author: Vikas Nelamangala Closes #9689 from vikasnp/master. --- docs/mllib-evaluation-metrics.md | 940 +----------------- ...avaBinaryClassificationMetricsExample.java | 113 +++ ...ultiLabelClassificationMetricsExample.java | 80 ++ ...ulticlassClassificationMetricsExample.java | 97 ++ .../mllib/JavaRankingMetricsExample.java | 176 ++++ .../mllib/JavaRegressionMetricsExample.java | 91 ++ .../binary_classification_metrics_example.py | 55 + .../mllib/multi_class_metrics_example.py | 69 ++ .../mllib/multi_label_metrics_example.py | 61 ++ .../python/mllib/ranking_metrics_example.py | 55 + .../mllib/regression_metrics_example.py | 59 ++ .../BinaryClassificationMetricsExample.scala | 103 ++ .../mllib/MultiLabelMetricsExample.scala | 69 ++ .../mllib/MulticlassMetricsExample.scala | 99 ++ .../mllib/RankingMetricsExample.scala | 110 ++ .../mllib/RegressionMetricsExample.scala | 67 ++ 16 files changed, 1319 insertions(+), 925 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java create mode 100644 examples/src/main/python/mllib/binary_classification_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_class_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_label_metrics_example.py create mode 100644 examples/src/main/python/mllib/ranking_metrics_example.py create mode 100644 examples/src/main/python/mllib/regression_metrics_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index f73eff637dc36..6924037b941f3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
    Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.recallByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
    Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
    @@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
    Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
    Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
    @@ -766,154 +390,21 @@ True classes:
    Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
    Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
    @@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
    @@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java new file mode 100644 index 0000000000000..980a9108af53f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaBinaryClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 0000000000000..b54e1ea3f2bcf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.rdd.RDD; +import org.apache.spark.SparkConf; +// $example off$ +import org.apache.spark.SparkContext; + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision + (metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure + (metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 0000000000000..21f628fb51b6e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision + (metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure + (metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 0000000000000..7c4c97e74681f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join + (userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 0000000000000..d2efc6bf97776 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 0000000000000..437acb998acc3 --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +import sys +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 0000000000000..cd56b3c97c778 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file mode 100644 index 0000000000000..960ade6597379 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 0000000000000..327791966c901 --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Ranking Metrics Example") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Read in the ratings data + lines = sc.textFile("data/mllib/sample_movielens_data.txt") + + def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) + ratings = lines.map(lambda r: parseLine(r)) + + # Train a model on to predict user-product ratings + model = ALS.train(ratings, 10, 10, 0.01) + + # Get predicted ratings on all existing user-product pairs + testData = ratings.map(lambda p: (p.user, p.product)) + predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + + ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) + scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + + # Instantiate regression metrics to compare predicted and actual ratings + metrics = RegressionMetrics(scoreAndLabels) + + # Root mean sqaured error + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + # $example off$ diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py new file mode 100644 index 0000000000000..a3a83aafd7a1f --- /dev/null +++ b/examples/src/main/python/mllib/regression_metrics_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 0000000000000..13a37827ab935 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object BinaryClassificationMetricsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new BinaryClassificationMetrics(predictionAndLabels) + + // Precision by threshold + val precision = metrics.precisionByThreshold + precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") + } + + // Recall by threshold + val recall = metrics.recallByThreshold + recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") + } + + // Precision-Recall Curve + val PRC = metrics.pr + + // F-measure + val f1Score = metrics.fMeasureByThreshold + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") + } + + val beta = 0.5 + val fScore = metrics.fMeasureByThreshold(beta) + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") + } + + // AUPRC + val auPRC = metrics.areaUnderPR + println("Area under precision-recall curve = " + auPRC) + + // Compute thresholds used in ROC and PR curves + val thresholds = precision.map(_._1) + + // ROC Curve + val roc = metrics.roc + + // AUROC + val auROC = metrics.areaUnderROC + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala new file mode 100644 index 0000000000000..4503c15360adc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 0000000000000..0904449245989 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 0000000000000..cffa03d5cc9f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala new file mode 100644 index 0000000000000..47d44532521ca --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println + +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RegressionMetricsExample { + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("RegressionMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + + // Build the model + val numIterations = 100 + val model = LinearRegressionWithSGD.train(data, numIterations) + + // Get predictions + val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) + } + + // Instantiate metrics object + val metrics = new RegressionMetrics(valuesAndPreds) + + // Squared error + println(s"MSE = ${metrics.meanSquaredError}") + println(s"RMSE = ${metrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${metrics.r2}") + + // Mean absolute error + println(s"MAE = ${metrics.meanAbsoluteError}") + + // Explained variance + println(s"Explained variance = ${metrics.explainedVariance}") + // $example off$ + } +} +// scalastyle:on println + From cf1a80dfb2fb5139aa7b98b92023aa502ae22d62 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 15:30:53 -0800 Subject: [PATCH 0861/1168] [SPARK-11787][SPARK-11883][SQL][FOLLOW-UP] Cleanup for this patch. This mainly moves SqlNewHadoopRDD to the sql package. There is some state that is shared between core and I've left that in core. This allows some other associated minor cleanup. Author: Nong Li Closes #9845 from nongli/spark-11787. --- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../spark/rdd/SqlNewHadoopRDDState.scala | 41 +++++++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 59 ++++++++++++++---- .../catalyst/expressions/InputFileName.scala | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 14 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 5 ++ .../datasources}/SqlNewHadoopRDD.scala | 60 +++++++------------ .../datasources/parquet/ParquetRelation.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 43 ++++++------- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++ 10 files changed, 175 insertions(+), 80 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala rename {core/src/main/scala/org/apache/spark/rdd => sql/core/src/main/scala/org/apache/spark/sql/execution/datasources}/SqlNewHadoopRDD.scala (86%) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 7db583468792e..f37c95bedc0a5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,8 @@ class HadoopRDD[K, V]( // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -256,7 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala new file mode 100644 index 0000000000000..3f15fff793661 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. + * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + */ +private[spark] object SqlNewHadoopRDDState { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 33769363a0ed5..b6979d0c82977 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.*; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -26,12 +30,26 @@ import java.util.HashSet; import java.util.Set; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CalendarIntervalType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -39,9 +57,23 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.NullType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import static org.apache.spark.sql.types.DataTypes.TimestampType; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -116,11 +148,6 @@ public static boolean isMutable(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -187,6 +214,12 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, numFields, sizeInBytes); } + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index d809877817a5b..bf215783fc27d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{DataType, StringType} @@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + SqlNewHadoopRDDState.getInputFileName() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 8a92e489ccb7c..dade488ca281b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -108,6 +108,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private static final int DEFAULT_VAR_LEN_SIZE = 32; + /** + * Tries to initialize the reader for this split. Returns true if this reader supports reading + * this split and false otherwise. + */ + public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) { + try { + initialize(inputSplit, taskAttemptContext); + return true; + } catch (Exception e) { + return false; + } + } + /** * Implementation of RecordReader API. */ @@ -326,6 +339,7 @@ private void decodeBinaryBatch(int col, int num) throws IOException { } else { rowWriters[n].write(col, bytes.array(), bytes.position(), len); } + rows[n].setNotNullAt(col); } else { rows[n].setNullAt(col); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f40e603cd1939..5ef3a48c56a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -323,6 +323,11 @@ private[spark] object SQLConf { "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableUnsafeRowRecordReader", + defaultValue = Some(true), + doc = "Enables using the custom ParquetUnsafeRowRecordReader.") + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 4d176332b69ce..56cb63d9eff2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import scala.reflect.ClassTag - private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sc : SparkContext, + sqlContext: SQLContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sc, Nil) + extends RDD[V](sqlContext.sparkContext, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). protected val enableUnsafeRowParquetReader: Boolean = - sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) @@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( * TODO: plumb this through a different way? */ if (enableUnsafeRowParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - // TODO: move this class to sql.execution and remove this. - reader = Utils.classForName( - "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") - .newInstance().asInstanceOf[RecordReader[Void, V]] - try { - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } catch { - case e: Exception => reader = null + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + if (!parquetReader.tryInitialize( + split.serializableHadoopSplit.value, hadoopAttemptContext)) { + parquetReader.close() + } else { + reader = parquetReader.asInstanceOf[RecordReader[Void, V]] } } @@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { @@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } super.persist(storageLevel) } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index cb0aab8cc0d09..fdd745f48e973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -319,7 +319,7 @@ private[sql] class ParquetRelation( Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, + sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c8028a5ef5528..cc5aae03d5516 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does - // not do row by row filtering (and we probably don't want to push that). - ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") - - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - - // The result should be single row. - // When a filter is pushed to Parquet, Parquet can apply it to every row. - // So, we can check the number of rows returned from the Parquet - // to make sure our filter pushdown work. - assert(childRDD.count == 1) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // This is the source RDD without Spark-side filtering. + val childRDD = + df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + + // The result should be single row. + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(childRDD.count == 1) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 177ab42f7767c..0c5d4887ed799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. + val data = sqlContext.range(200).map { i => + if (i.getLong(0) < 150) Row(None) + else Row("a") + } + val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil)) + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + val df2 = sqlContext.read.parquet(path) + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary From fbfea3b6ec8f513a8b62cd437e37cc639956cac1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:36:30 -0800 Subject: [PATCH 0862/1168] [SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` :38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust Closes #9870 from marmbrus/dataset-repl-agg. --- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 27 +++---------------- .../apache/spark/sql/JavaDatasetSuite.java | 8 +++--- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002cc6..cbcccb11f14ae 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. */ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 6de3dd626576a..263f049104762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - /** - * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. - * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. - * - * @since 1.6.0 - */ + // This is here to prevent us from adding overloads that would be ambiguous. @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ce40dd856f679..f7249b8945c49 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -404,11 +404,9 @@ public String call(Tuple2 value) throws Exception { grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), - expr("sum(_2)"), - count("*")) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( new Tuple4<>("a", 3, 3L, 2L), From 4a7083801e89e0730a38e2b72fc342759264a17a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:38:04 -0800 Subject: [PATCH 0863/1168] [SPARK-11890][SQL] Fix compilation for Scala 2.11 Author: Michael Armbrust Closes #9871 from marmbrus/scala211-break. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 918050b531c02..4a4a62ed1a468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - private def silentSchemaFor(tpe: `Type`): Schema = try { + protected def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - private def getClassNameFromType(tpe: `Type`): String = { + protected def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From f8f179cc77b11959d0383336b39c06643c2fbb12 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 16:02:03 -0800 Subject: [PATCH 0864/1168] [HOTFIX] Fix Java Dataset Tests --- .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f7249b8945c49..f32374b4c04df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -409,8 +409,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( - new Tuple4<>("a", 3, 3L, 2L), - new Tuple4<>("b", 3, 3L, 1L)), + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), agged2.collectAsList()); } From 0a14df2d80ef352f2c11701c35275db22ed9d8bc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 20 Nov 2015 16:51:47 -0800 Subject: [PATCH 0865/1168] Revert "[SPARK-11689][ML] Add user guide and example code for LDA under spark.ml" This reverts commit e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7. --- docs/ml-clustering.md | 30 ------ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 - .../spark/examples/ml/JavaLDAExample.java | 94 ------------------- .../apache/spark/examples/ml/LDAExample.scala | 77 --------------- 5 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 docs/ml-clustering.md delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md deleted file mode 100644 index 1743ef43a6ddf..0000000000000 --- a/docs/ml-clustering.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -layout: global -title: Clustering - ML -displayTitle: ML - Clustering ---- - -In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). - -## Latent Dirichlet allocation (LDA) - -`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, -and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by -`EMLDAOptimizer` to a `DistributedLDAModel` if needed. - -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. - -
    -{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. - -{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} -
    - -
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 6f35b30c3d4df..be18a05361a17 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,7 +40,6 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -951,4 +950,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15af..91e50ccfecec4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,7 +69,6 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java deleted file mode 100644 index b3a7d2eb29780..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.ml.clustering.LDA; -import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** - * An example demonstrating LDA - * Run with - *
    - * bin/run-example ml.JavaLDAExample
    - * 
    - */ -public class JavaLDAExample { - - private static class ParseVector implements Function { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - - public static void main(String[] args) { - - String inputFile = "data/mllib/sample_lda_data.txt"; - - // Parses the arguments - SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // Loads data - JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - DataFrame dataset = sqlContext.createDataFrame(points, schema); - - // Trains a LDA model - LDA lda = new LDA() - .setK(10) - .setMaxIter(10); - LDAModel model = lda.fit(dataset); - - System.out.println(model.logLikelihood(dataset)); - System.out.println(model.logPerplexity(dataset)); - - // Shows the result - DataFrame topics = model.describeTopics(3); - topics.show(false); - model.transform(dataset).show(false); - - jsc.stop(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala deleted file mode 100644 index 419ce3d87a6ac..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml - -// scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -// $example on$ -import org.apache.spark.ml.clustering.LDA -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} -// $example off$ - -/** - * An example demonstrating a LDA of ML pipeline. - * Run with - * {{{ - * bin/run-example ml.LDAExample - * }}} - */ -object LDAExample { - - final val FEATURES_COL = "features" - - def main(args: Array[String]): Unit = { - - val input = "data/mllib/sample_lda_data.txt" - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) - - // Trains a LDA model - val lda = new LDA() - .setK(10) - .setMaxIter(10) - .setFeaturesCol(FEATURES_COL) - val model = lda.fit(dataset) - val transformed = model.transform(dataset) - - val ll = model.logLikelihood(dataset) - val lp = model.logPerplexity(dataset) - - // describeTopics - val topics = model.describeTopics(3) - - // Shows the result - topics.show(false) - transformed.show(false) - - // $example off$ - sc.stop() - } -} -// scalastyle:on println From 71317b0cd33cd5fbb5eb7b4be8fa39062193f8a3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 23:31:19 -0800 Subject: [PATCH 0866/1168] [SPARK-11819][SQL][FOLLOW-UP] fix scala 2.11 build seems scala 2.11 doesn't support: define private methods in `trait xxx` and use it in `object xxx extend xxx`. Author: Wenchen Fan Closes #9879 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4a4a62ed1a468..476becec4dd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - protected def silentSchemaFor(tpe: `Type`): Schema = try { + def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - protected def getClassNameFromType(tpe: `Type`): String = { + def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 40e0d7a0b96beabc2d34ce553844e4b2c054a94a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:10:13 -0800 Subject: [PATCH 0867/1168] [SPARK-11900][SQL] Add since version for all encoders Author: Reynold Xin Closes #9881 from rxin/SPARK-11900. --- .../scala/org/apache/spark/sql/Encoder.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 86bb536459035..5cb8edf64e87c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,13 +45,52 @@ trait Encoder[T] extends Serializable { */ object Encoders { + /** + * An encoder for nullable boolean type. + * @since 1.6.0 + */ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * @since 1.6.0 + */ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * @since 1.6.0 + */ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * @since 1.6.0 + */ def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * @since 1.6.0 + */ def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * @since 1.6.0 + */ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * @since 1.6.0 + */ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * @since 1.6.0 + */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** @@ -59,6 +98,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) @@ -67,6 +108,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -77,6 +120,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -87,6 +132,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) @@ -120,12 +167,20 @@ object Encoders { ) } + /** + * An encoder for 2-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2]( e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) } + /** + * An encoder for 3-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3]( e1: Encoder[T1], e2: Encoder[T2], @@ -133,6 +188,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) } + /** + * An encoder for 4-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4]( e1: Encoder[T1], e2: Encoder[T2], @@ -141,6 +200,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) } + /** + * An encoder for 5-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4, T5]( e1: Encoder[T1], e2: Encoder[T2], From a42cb6208c9af3a40068fd8ef8dbb73567dc5111 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:54:18 -0800 Subject: [PATCH 0868/1168] [SPARK-11901][SQL] API audit for Aggregator. Author: Reynold Xin Closes #9882 from rxin/SPARK-11901. --- .../scala/org/apache/spark/sql/Dataset.scala | 1 - .../spark/sql/expressions/Aggregator.scala | 39 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bdcdc5d47cbae..07647508421a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 72610e735f782..b0cd32b5f73e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] @@ -32,55 +31,65 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} * case class Data(i: Int) * * val customSummer = new Aggregator[Data, Int, Int] { - * def zero = 0 - * def reduce(b: Int, a: Data) = b + a.i - * def present(r: Int) = r + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def present(r: Int): Int = r * }.toColumn() * - * val ds: Dataset[Data] + * val ds: Dataset[Data] = ... * val aggregated = ds.select(customSummer) * }}} * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam A The input type for the aggregation. + * @tparam I The input type for the aggregation. * @tparam B The type of the intermediate value of the reduction. - * @tparam C The type of the final result. + * @tparam O The type of the final output result. + * + * @since 1.6.0 */ -abstract class Aggregator[-A, B, C] extends Serializable { +abstract class Aggregator[-I, B, O] extends Serializable { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ def zero: B /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. + * @since 1.6.0 */ - def reduce(b: B, a: A): B + def reduce(b: B, a: I): B /** - * Merge two intermediate values + * Merge two intermediate values. + * @since 1.6.0 */ def merge(b1: B, b2: B): B /** * Transform the output of the reduction. + * @since 1.6.0 */ - def finish(reduction: B): C + def finish(reduction: B): O /** * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] * operations. + * @since 1.6.0 */ def toColumn( implicit bEncoder: Encoder[B], - cEncoder: Encoder[C]): TypedColumn[A, C] = { + cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = new AggregateExpression( TypedAggregateExpression(this), Complete, false) - new TypedColumn[A, C](expr, encoderFor[C]) + new TypedColumn[I, O](expr, encoderFor[O]) } } From 61699b49ce95a4a5166ab45c90085afd38c691c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 15:00:37 -0800 Subject: [PATCH 0869/1168] [SPARK-11899][SQL] API audit for GroupedDataset. 1. Renamed map to mapGroup, flatMap to flatMapGroup. 2. Renamed asKey -> keyAs. 3. Added more documentation. 4. Changed type parameter T to V on GroupedDataset. 5. Added since versions for all functions. Author: Reynold Xin Closes #9880 from rxin/SPARK-11899. --- .../api/java/function/MapGroupFunction.java | 2 +- .../scala/org/apache/spark/sql/Encoder.scala | 4 + .../sql/catalyst/JavaTypeInference.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 2 + .../org/apache/spark/sql/DataFrame.scala | 1 - .../org/apache/spark/sql/GroupedDataset.scala | 132 ++++++++++++++---- .../apache/spark/sql/JavaDatasetSuite.java | 8 +- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 20 +-- 9 files changed, 131 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java index 2935f9986a560..4f3f222e064bb 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -21,7 +21,7 @@ import java.util.Iterator; /** - * Base interface for a map function used in GroupedDataset's map function. + * Base interface for a map function used in GroupedDataset's mapGroup function. */ public interface MapGroupFunction extends Serializable { R call(K key, Iterator values) throws Exception; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 5cb8edf64e87c..03aa25eda807f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.types._ * * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. + * + * @since 1.6.0 */ trait Encoder[T] extends Serializable { @@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable { /** * Methods for creating encoders. + * + * @since 1.6.0 */ object Encoders { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4e..7d4cfbe6faecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ /** * Type-inference utilities for POJOs and Java collections. */ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) @@ -53,7 +53,6 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 82e9cd7f50a31..30c554a85e693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -46,6 +46,8 @@ private[sql] object Column { * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. + * + * @since 1.6.0 */ class TypedColumn[-T, U]( expr: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7abcecaa2880e..5586fc994b98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -110,7 +110,6 @@ private[sql] object DataFrame { * @groupname action Actions * @since 1.3.0 */ -// TODO: Improve documentation. @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 263f049104762..7f43ce16901b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator /** * :: Experimental :: @@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution * making this change to the class hierarchy would break some function signatures. As such, this * class should be considered a preview of the final API. Changes will be made to the interface * after Spark 1.6. + * + * @since 1.6.0 */ @Experimental -class GroupedDataset[K, T] private[sql]( +class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[T], + tEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + * + * @since 1.6.0 */ - def asKey[L : Encoder]: GroupedDataset[L, T] = + def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedTEncoder, @@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 */ def keys: Dataset[K] = { new Dataset[K]( @@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql]( * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql]( logicalPlan)) } - def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql]( * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) - flatMap(func) + def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroup(func) } - def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - map((key, data) => f.call(key, data.asJava))(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroup((key, data) => f.call(key, data.asJava))(encoder) } /** * Reduces the elements of each group of data using the specified binary function. * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 */ - def reduce(f: (T, T) => T): Dataset[(K, T)] = { - val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMap(func) + flatMapGroup(func) } - def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { reduce(f.call _) } @@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql]( /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3, U4]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3], - col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. + * + * @since 1.6.0 */ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) @@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql]( * be passed the grouping key and 2 iterators containing all elements in the group from * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, @@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql]( other.logicalPlan)) } + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ def cogroup[U, R]( other: GroupedDataset[K, U], - f: CoGroupFunction[K, T, U, R], + f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f32374b4c04df..cf335efdd23b8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.map(new MapGroupFunction() { + Dataset mapped = grouped.mapGroup(new MapGroupFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,7 +183,7 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMap( + Dataset flatMapped = grouped.flatMapGroup( new FlatMapGroupFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { @@ -247,9 +247,9 @@ public void testGroupByColumn() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); GroupedDataset grouped = - ds.groupBy(length(col("value"))).asKey(Encoders.INT()); + ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.map( + Dataset mapped = grouped.mapGroup( new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 63b00975e4eb1..d387710357be0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.map { case (g, iter) => + val agged = grouped.mapGroup { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e469..9da02550b39ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + val agged = grouped.flatMapGroup { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) + } checkAnswer( agged, @@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From db1f66a8ebe60b8d2d1a54148871e16aace586fd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 22 Nov 2015 10:36:47 -0800 Subject: [PATCH 0870/1168] [SPARK-11908][SQL] Add NullType support to RowEncoder JIRA: https://issues.apache.org/jira/browse/SPARK-11908 We should add NullType support to RowEncoder. Author: Liang-Chi Hsieh Closes #9891 from viirya/rowencoder-nulltype. --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5 +++-- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 3 +++ .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4cda4824acdc3..fa553e7c5324c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -48,7 +48,7 @@ object RowEncoder { private def extractorsFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject case udt: UserDefinedType[_] => @@ -143,6 +143,7 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } private def constructorFor(schema: StructType): Expression = { @@ -158,7 +159,7 @@ object RowEncoder { } private def constructorFor(input: Expression): Expression = input.dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index ef7399e0196ab..82317d3385167 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -369,6 +369,9 @@ case class MapObjects( private lazy val completeFunction = function(loopAttribute) private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" case IntegerType => (i: String) => s".getInt($i)" case LongType => (i: String) => s".getLong($i)" case FloatType => (i: String) => s".getFloat($i)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 46c6e0d98d349..0ea51ece4bc5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() + .add("null", NullType) .add("boolean", BooleanType) .add("byte", ByteType) .add("short", ShortType) @@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite { encodeDecodeTest( new StructType() + .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) From 1a2135760cd4f23a0fa068359b9909b504c040ae Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 22 Nov 2015 21:45:46 -0800 Subject: [PATCH 0871/1168] [SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples We used the name `Dataset` to refer to `SchemaRDD` in 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for dense format to simplify the example code. cc: yinxusen Author: Xiangrui Meng Closes #9873 from mengxr/SPARK-11895. --- .../DataFrameExample.scala} | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala} (51%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af7..424f00158c2f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println From c66a348432b22fa7b483f855fbba5213129e185a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH 0872/1168] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abcb30..4b2b3f8489fd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c3..bfb70963b151d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d9..c44db0ec595ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e0..daaa174a086e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9b0f..b798aa1fab767 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b9..83a9048374267 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. + @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. + // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba122a..8484b1f801066 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6dbeb..8c86767456368 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe66777..a535c1218ecfa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b3701..7ee65975d22f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323935..60886bf77d2f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a0337..dd6366050c020 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] From b94a50002d1d5c9249479e77e65c65ee9ee40031 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Sun, 22 Nov 2015 21:51:42 -0800 Subject: [PATCH 0873/1168] [SPARK-11835] Adds a sidebar menu to MLlib's documentation This PR adds a sidebar menu when browsing the user guide of MLlib. It uses a YAML file to describe the structure of the documentation. It should be trivial to adapt this to the other projects. ![screen shot 2015-11-18 at 4 46 12 pm](https://cloud.githubusercontent.com/assets/7594753/11259591/a55173f4-8e17-11e5-9340-0aed79d66262.png) Author: Timothy Hunter Closes #9826 from thunterdb/spark-11835. --- docs/_data/menu-ml.yaml | 10 ++++ docs/_data/menu-mllib.yaml | 75 +++++++++++++++++++++++++ docs/_includes/nav-left-wrapper-ml.html | 8 +++ docs/_includes/nav-left.html | 17 ++++++ docs/_layouts/global.html | 24 +++++--- docs/css/main.css | 37 ++++++++++++ 6 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 docs/_data/menu-ml.yaml create mode 100644 docs/_data/menu-mllib.yaml create mode 100644 docs/_includes/nav-left-wrapper-ml.html create mode 100644 docs/_includes/nav-left.html diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 0000000000000..dff3d33bf4ed1 --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: Feature extraction, transformation, and selection + url: ml-features.html +- text: Decision trees for classification and regression + url: ml-decision-tree.html +- text: Ensembles + url: ml-ensembles.html +- text: Linear methods with elastic-net regularization + url: ml-linear-methods.html +- text: Multilayer perceptron classifier + url: ml-ann.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 0000000000000..12d22abd52826 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 0000000000000..0103e890cc21a --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
    +
    +

    spark.ml package

    + {% include nav-left.html nav=include.nav-ml %} +

    spark.mllib package

    + {% include nav-left.html nav=include.nav-mllib %} +
    +
    \ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 0000000000000..73176f4132554 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 467ff7a03fb70..1b09e2221e173 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -124,16 +124,24 @@ -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    - {{ content }} + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + {% endif %} -
    + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} + +
    +
    diff --git a/docs/css/main.css b/docs/css/main.css index d770173be1014..356b324d6303b 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,18 @@ margin-left: 10px; } +body .container-wrapper { + position: absolute; + width: 100%; + display: flex; +} + body #content { + position: relative; + line-height: 1.6; /* Inspired by Github's wiki style */ + background-color: white; + padding-left: 15px; } .title { @@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + position: absolute; + height: 100%; + + width: 256px; + margin-top: -20px; + padding-top: 20px; + background-color: #F0F8FC; +} + +.left-menu { + position: fixed; + max-width: 350px; + + padding-right: 10px; + width: 256px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} \ No newline at end of file From 3360ab4513307e8bedec67d0b96998d971cd6cec Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 22 Nov 2015 21:56:07 -0800 Subject: [PATCH 0874/1168] [SPARK-11912][ML] ml.feature.PCA minor refactor Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang Closes #9897 from yanboliang/spark-11912. --- .../org/apache/spark/ml/feature/PCA.scala | 23 +++++++------- .../apache/spark/ml/feature/PCASuite.scala | 31 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 32d7afee6e73b..aa88cb03d23c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] { /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) + val pc: DenseMatrix) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ - /** a principal components Matrix. Each column is one principal component. */ - val pc: DenseMatrix = pcaModel.pc - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -124,6 +123,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -139,7 +139,7 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc) copyValues(copied, extra).setParent(parent) } @@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] { private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(k: Int, pc: DenseMatrix) + private case class Data(pc: DenseMatrix) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.getK, instance.pc) + val data = Data(instance.pc) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) - .select("k", "pc") + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("pc") .head() - val oldModel = new feature.PCAModel(k, pc) - val model = new PCAModel(metadata.uid, oldModel) + val model = new PCAModel(metadata.uid, pc) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 5a21cd20ceede..edab21e6c3072 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val model = new PCAModel("pca", mat) ParamsSuite.checkParams(model) } @@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } } - test("read/write") { + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } - def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { - assert(model1.pc === model2.pc) - } - val allParams: Map[String, Any] = Map( - "k" -> 3, - "inputCol" -> "features", - "outputCol" -> "pca_features" - ) - val data = Seq( - (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), - (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - ) - val df = sqlContext.createDataFrame(data).toDF("id", "features") - val pca = new PCA().setK(3) - testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.pc === instance.pc) } } From dce3130a3bc89a1e5c6760625e21210f81b42e83 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [PATCH 0875/1168] [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec0549852b..801096fed27bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9bf2c..9c1c00f41ab1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) From 60b46143c00bd758617e86551195008ec3864711 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 23 Nov 2015 08:53:40 -0800 Subject: [PATCH 0876/1168] [SPARK-11628][SQL] support column datatype of char(x) to recognize HiveChar Can someone review my code to make sure I'm not missing anything? Thanks! Author: Xiu Guo Author: Xiu Guo Closes #9612 from xguo27/SPARK-11628. --- .../sql/catalyst/util/DataTypeParser.scala | 6 ++++- .../catalyst/util/DataTypeParserSuite.scala | 8 ++++-- .../spark/sql/sources/TableScanSuite.scala | 5 ++++ .../spark/sql/hive/HiveInspectors.scala | 25 ++++++++++++++++--- .../apache/spark/sql/hive/TableReader.scala | 3 +++ .../spark/sql/hive/client/HiveShim.scala | 3 ++- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 2b83651f9086d..515c071c283b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1e3409a9db6eb..bebf708965474 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. checkDataType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398f..26c1ff520406c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3d..95b57d6ad124a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -61,6 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +76,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +95,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +140,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -377,6 +387,15 @@ private[hive] trait HiveInspectors { null } + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } + case _: JavaHiveDecimalObjectInspector => (o: Any) => if (o != null) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a655..70ee02823eeba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -382,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1de..346840079b853 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -321,7 +321,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { From 14edb8eae0803435d9d6b22c8c7147526af972ca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:13:59 -0800 Subject: [PATCH 0877/1168] [SPARK-11894][SQL] fix isNull for GetInternalRowField We should use `InternalRow.isNullAt` to check if the field is null before calling `InternalRow.getXXX` Thanks gatorsmile who discovered this bug. Author: Wenchen Fan Closes #9904 from cloud-fan/null. --- .../sql/catalyst/expressions/objects.scala | 23 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 82317d3385167..4a1f419f0ad8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -236,11 +236,6 @@ case class NewInstance( } if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val row = child.gen(ctx) - s""" - ${row.code} - final boolean ${ev.isNull} = ${row.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; - } - """ + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9da02550b39ce..cc8e4325fd2f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } } From 29fff102c4538122c8e18423b2bc26a9d688252f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [PATCH 0878/1168] [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../catalyst/encoders/ExpressionEncoder.scala | 15 +++++++- .../encoders/ExpressionEncoderSuite.scala | 38 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f3..7bc9aed0b204e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a484f..d6ca138672ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( From ee621318b57dddd01a331981548f9bd59d4654b1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:39:33 -0800 Subject: [PATCH 0879/1168] [SPARK-11913][SQL] support typed aggregate with complex buffer schema Author: Wenchen Fan Closes #9898 from cloud-fan/agg. --- .../aggregate/TypedAggregateExpression.scala | 25 +++++++---- .../spark/sql/DatasetAggregatorSuite.scala | 41 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6ce41aaf01e27..a9719128a626e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -23,9 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -46,14 +45,12 @@ object TypedAggregateExpression { /** * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has * the following limitations: - * - It assumes the aggregator reduces and returns a single column of type `long`. - * - It might only work when there is a single aggregator in the first column. * - It assumes the aggregator has a zero, `0`. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - bEncoder: ExpressionEncoder[Any], // Should be bound. + unresolvedBEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -80,10 +77,14 @@ case class TypedAggregateExpression( override lazy val inputTypes: Seq[DataType] = Nil - override val aggBufferSchema: StructType = bEncoder.schema + override val aggBufferSchema: StructType = unresolvedBEncoder.schema override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -93,12 +94,18 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - // todo: need a more neat way to assign the value. var i = 0 while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i aggBufferSchema(i).dataType match { - case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) - case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) } i += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 9377589790011..19dce5d1e2f37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L } case class AggData(a: Int, b: String) -object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { +object ClassInputAgg extends Aggregator[AggData, Int, Int] { /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { override def merge(b1: Int, b2: Int): Int = b1 + b2 } +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } } From a1bc9e0055fb7c43d7618ab2d03968982bb49484 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 23 Nov 2015 10:41:17 -0800 Subject: [PATCH 0880/1168] [SPARK-7173][YARN] Add label expression support for application master Add label expression support for AM to restrict it runs on the specific set of nodes. I tested it locally and works fine. sryza and vanzin please help to review, thanks a lot. Author: jerryshao Closes #9800 from jerryshao/SPARK-7173. --- docs/running-on-yarn.md | 9 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 26 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0fe..925a1e0ba6fcf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. + + spark.yarn.executor.nodeLabelExpression (none) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ba799884f5689..a77a3e2420e24 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -225,7 +225,31 @@ private[spark] class Client( val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } From 5cef7a7ef0426373b37f4109fc74825c93ba5ea5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 10:45:23 -0800 Subject: [PATCH 0881/1168] [SPARK-11762][NETWORK] Account for active streams when couting outstanding requests. This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin Closes #9747 from vanzin/SPARK-11762. --- .../network/client/StreamInterceptor.java | 12 ++++++++- .../client/TransportResponseHandler.java | 15 +++++++++-- .../TransportResponseHandlerSuite.java | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69fc..88ba3ccebdf20 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,19 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private volatile long bytesRead; - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af58048..cc88991b588c1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. @@ -177,14 +185,16 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); + streamActive = true; } catch (Exception e) { logger.error("Error installing stream handler.", e); + deactivateStream(); } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a93..30144f4a9fc7a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -28,12 +29,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test @@ -112,4 +117,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, handler.numOutstandingRequests()); + } } From 8f40a1b314abc6e90b707f3dad191a3886ce18b6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Nov 2015 11:51:29 -0800 Subject: [PATCH 0882/1168] [SPARK-11920][ML][DOC] ML LinearRegression should use correct dataset in examples and user guide doc ML ```LinearRegression``` use ```data/mllib/sample_libsvm_data.txt``` as dataset in examples and user guide doc, but it's actually classification dataset rather than regression dataset. We should use ```data/mllib/sample_linear_regression_data.txt``` instead. The deeper causes is that ```LinearRegression``` with "normal" solver can not solve this dataset correctly, may be due to the ill condition and unreasonable label. This issue has been reported at [SPARK-11918](https://issues.apache.org/jira/browse/SPARK-11918). It will confuse users if they run the example code but get exception, so we should make this change which can clearly illustrate the usage of ```LinearRegression``` algorithm. Author: Yanbo Liang Closes #9905 from yanboliang/spark-11920. --- .../examples/ml/JavaLinearRegressionWithElasticNetExample.java | 2 +- .../src/main/python/ml/linear_regression_with_elastic_net.py | 3 ++- .../examples/ml/LinearRegressionWithElasticNetExample.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index 593f8fb3e9fe9..4ad7676c8d32b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // $example on$ // Load training data DataFrame training = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); + .load("data/mllib/sample_linear_regression_data.txt"); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index b0278276330c3..a4cd40cf26726 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -29,7 +29,8 @@ # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 5a51ece6f9ba7..22c824cea84d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample { // $example on$ // Load training data - val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() .setMaxIter(10) From 92349266ee3501dcbf8b2411f3bd5b136ce61c4a Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Mon, 23 Nov 2015 12:03:15 -0800 Subject: [PATCH 0883/1168] [SPARK-11837][EC2] python3 compatibility for launching ec2 m3 instances this currently breaks for python3 because `string` module doesn't have `letters` anymore, instead `ascii_letters` should be used Author: Mortada Mehyar Closes #9797 from mortada/python3_fix. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9327e21e43db7..9fd652a3df4c4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves From 5056b24a11913bfec29cc5f136893880a6dc69f6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 13:19:10 -0800 Subject: [PATCH 0884/1168] [SPARK-4424] Remove spark.driver.allowMultipleContexts override in tests This patch removes `spark.driver.allowMultipleContexts=true` from our test configuration. The multiple SparkContexts check was originally disabled because certain tests suites in SQL needed to create multiple contexts. As far as I know, this configuration change is no longer necessary, so we should remove it in order to make it easier to find test cleanup bugs. Author: Josh Rosen Closes #9865 from JoshRosen/SPARK-4424. --- pom.xml | 2 -- project/SparkBuild.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad849112ce76c..234fd5dea1a6e 100644 --- a/pom.xml +++ b/pom.xml @@ -1958,7 +1958,6 @@ false false false - true true src @@ -1997,7 +1996,6 @@ 1 false false - true true __not_used__ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e411..f575f0012d59e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -632,7 +632,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", From 104f61445ce5a9aa9afa4ca0e09cfc5f4540ced1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 23 Nov 2015 13:44:30 -0800 Subject: [PATCH 0885/1168] [SPARK-11836][SQL] udf/cast should not create new SQLContext They should use the existing SQLContext. Author: Davies Liu Closes #9914 from davies/create_udf. --- python/pyspark/sql/column.py | 7 ++++--- python/pyspark/sql/functions.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cfa..81fd4e782628a 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -346,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3da513c13897..a1ca723bbd7ab 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, From a84d6f5199220d0360b935ad86ba128cf344c43d Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 23 Nov 2015 13:46:34 -0800 Subject: [PATCH 0886/1168] [SPARK-11910][STREAMING][DOCS] Update twitter4j dependency version Author: Luciano Resende Closes #9892 from lresende/SPARK-11910. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96b36b7a73209..ed6b28c282135 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -723,7 +723,7 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. You can either get the public stream, or get the filtered stream based on a From d989b639c960a15bee9c91b393e3a2d87d092df5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:51:43 -0800 Subject: [PATCH 0887/1168] [SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory. There's a very narrow race here where it would be possible for the timeout handler to close a channel after the client factory verified that the channel was still active. This change makes sure the client is marked as being recently in use so that the timeout handler does not close it until a new timeout cycle elapses. Author: Marcelo Vanzin Closes #9853 from vanzin/SPARK-11865. --- .../spark/network/client/TransportClient.java | 9 ++++- .../client/TransportClientFactory.java | 15 ++++++-- .../client/TransportResponseHandler.java | 9 +++-- .../server/TransportChannelHandler.java | 36 ++++++++++++------- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a2..876fcd846791c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -73,10 +73,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +86,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -263,6 +265,11 @@ public void onFailure(Throwable e) { } } + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 42a4f664e697c..659c47160c7be 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. + TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index cc88991b588c1..be181e0660826 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -71,7 +71,7 @@ public TransportResponseHandler(Channel channel) { } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -80,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -227,4 +227,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f8fcd1c3d7d76..29d688a67578c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -116,20 +116,32 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc // there are outstanding requests, we also do a secondary consistency check to ensure // there's no race between the idle timeout and incrementing the numOutstandingRequests // (see SPARK-7003). - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - ctx.close(); + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. + synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } } } } } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } From 517a74b6efe666972202e9a57b503e82fcec7410 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:54:19 -0800 Subject: [PATCH 0888/1168] [SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. Author: Marcelo Vanzin Closes #9530 from vanzin/SPARK-11140. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 ++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 138 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 39 ++++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- docs/configuration.md | 2 + docs/security.md | 5 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 6 +- .../server/TransportChannelHandler.java | 1 + 14 files changed, 356 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af4456c05b0a1..b153a7b08e590 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 88df27f733f2a..84230e32a4462 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -367,17 +365,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -422,7 +409,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf76..3d7d281b0dd66 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 059a7e10ec12f..94dbec593c315 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3ce359868039b..68701f609f77a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = error = e + + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error + } + source.read(dst) + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. @@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] @@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 0000000000000..eb1d2604fb235 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. + */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1b3acb8ef7f51..af632349c9cae 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection @@ -535,6 +536,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2f55006420ce1..2b664c6313efa 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b66..ccca795683da3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). - thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/docs/configuration.md b/docs/configuration.md index c496146e3ed63..4de202d7f7631 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend. @@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180b..e1af221d446b0 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. Executor diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47fc..55fe156cf665f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 659c47160c7be..61bafc8380049 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -170,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } /** - * Create a completely new {@link TransportClient} to the given remote host / port - * But this connection is not pooled. + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 29d688a67578c..3164e00679035 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -138,6 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } + ctx.fireUserEventTriggered(evt); } public TransportResponseHandler getResponseHandler() { From 398e7db585e81bd83adbd22be8838dd90f63f11f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 16:33:26 -0800 Subject: [PATCH 0889/1168] [SPARK-9866][SQL] Speed up VersionsSuite by using persistent Ivy cache This patch attempts to speed up VersionsSuite by storing fetched Hive JARs in an Ivy cache that persists across tests runs. If `SPARK_VERSIONS_SUITE_IVY_PATH` is set, that path will be used for the cache; if it is not set, VersionsSuite will create a temporary Ivy cache which is deleted after the test completes. Author: Josh Rosen Closes #9624 from JoshRosen/SPARK-9866. --- .../apache/spark/sql/hive/client/VersionsSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c6d034a23a1c6..7bc13bc60d30e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -36,10 +36,12 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() From 8fb0b02947afb12209ccfe5b70f14a69b39496b4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Nov 2015 17:11:51 -0800 Subject: [PATCH 0890/1168] [SPARK-10560][PYSPARK][MLLIB][DOCS] Make StreamingLogisticRegressionWithSGD Python API equal to Scala one This is to bring the API documentation of StreamingLogisticReressionWithSGD and StreamingLinearRegressionWithSGC in line with the Scala versions. -Fixed the algorithm descriptions -Added default values to parameter descriptions -Changed StreamingLogisticRegressionWithSGD regParam to default to 0, as in the Scala version Author: Bryan Cutler Closes #9141 from BryanCutler/StreamingLogisticRegressionWithSGD-python-api-sync. --- python/pyspark/mllib/classification.py | 37 +++++++++++++++++--------- python/pyspark/mllib/regression.py | 32 ++++++++++++++-------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index aab4015ba80f8..9e6f17ef6e942 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -652,21 +652,34 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a logistic regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 6f00d1df209c0..13b3397501c0b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -734,17 +734,27 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Total number of iterations run. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a linear regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ From 4911efa1506d9bfd614d3bafb038aef45ee3c1c8 Mon Sep 17 00:00:00 2001 From: Stephen Samuel Date: Mon, 23 Nov 2015 19:52:12 -0800 Subject: [PATCH 0891/1168] Updated sql programming guide to include jdbc fetch size Author: Stephen Samuel Closes #9377 from sksamuel/master. --- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e347754055e79..d7b205c2fa0df 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1820,6 +1820,7 @@ the Data Sources API. The following options are supported: register itself with the JDBC subsystem. + partitionColumn, lowerBound, upperBound, numPartitions @@ -1831,6 +1832,13 @@ the Data Sources API. The following options are supported: partitioned and returned. + + + fetchSize + + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + +
    From 891253913ca3b94721598e973678f78e4e5e56ee Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Nov 2015 22:22:15 -0800 Subject: [PATCH 0892/1168] [SPARK-11933][SQL] Rename mapGroup -> mapGroups and flatMapGroup -> flatMapGroups. Based on feedback from Matei, this is more consistent with mapPartitions in Spark. Also addresses some of the cleanups from a previous commit that renames the type variables. Author: Reynold Xin Closes #9919 from rxin/SPARK-11933. --- ...nction.java => FlatMapGroupsFunction.java} | 2 +- ...upFunction.java => MapGroupsFunction.java} | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 36 +++++++++---------- .../apache/spark/sql/JavaDatasetSuite.java | 10 +++--- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +-- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 6 files changed, 33 insertions(+), 33 deletions(-) rename core/src/main/java/org/apache/spark/api/java/function/{FlatMapGroupFunction.java => FlatMapGroupsFunction.java} (93%) rename core/src/main/java/org/apache/spark/api/java/function/{MapGroupFunction.java => MapGroupsFunction.java} (93%) diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index 18a2d733ca70d..d7a80e7b129b0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -23,6 +23,6 @@ /** * A function that returns zero or more output records from each grouping key and its values. */ -public interface FlatMapGroupFunction extends Serializable { +public interface FlatMapGroupsFunction extends Serializable { Iterable call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java index 4f3f222e064bb..faa59eabc8b4f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -23,6 +23,6 @@ /** * Base interface for a map function used in GroupedDataset's mapGroup function. */ -public interface MapGroupFunction extends Serializable { +public interface MapGroupsFunction extends Serializable { R call(K key, Iterator values) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7f43ce16901b9..793a86b132907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.expressions.Aggregator @Experimental class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[V], + vEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -53,12 +53,12 @@ class GroupedDataset[K, V] private[sql]( // queryexecution. private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedTEncoder = encoderFor(tEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedTEncoder = - unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -76,7 +76,7 @@ class GroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], - unresolvedTEncoder, + unresolvedVEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -110,13 +110,13 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( f, resolvedKEncoder, - resolvedTEncoder, + resolvedVEncoder, groupingAttributes, logicalPlan)) } @@ -138,8 +138,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -158,9 +158,9 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroup(func) + flatMapGroups(func) } /** @@ -179,8 +179,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroup((key, data) => f.call(key, data.asJava))(encoder) + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) } /** @@ -192,8 +192,8 @@ class GroupedDataset[K, V] private[sql]( def reduce(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMapGroup(func) + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) } /** @@ -213,7 +213,7 @@ class GroupedDataset[K, V] private[sql]( private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => - tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes) + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) case _ => c } @@ -227,7 +227,7 @@ class GroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder, dataAttributes).named) + _.withInputType(resolvedVEncoder, dataAttributes).named) val keyColumn = if (groupingAttributes.length > 1) { Alias(CreateStruct(groupingAttributes), "key")() } else { @@ -304,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedTEncoder + implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index cf335efdd23b8..67a3190cb7d4f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.mapGroup(new MapGroupFunction() { + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,8 +183,8 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMapGroup( - new FlatMapGroupFunction() { + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -249,8 +249,8 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.mapGroup( - new MapGroupFunction() { + Dataset mapped = grouped.mapGroups( + new MapGroupsFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index d387710357be0..f75d0961823c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroup { case (g, iter) => + val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc8e4325fd2f5..dbdd7ba14a5b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMapGroup { case (g, iter) => + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } @@ -255,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -265,7 +265,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -275,7 +275,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -285,7 +285,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 5966b858387d22bf4ff151394fc450438d6e4291 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 23 Nov 2015 22:22:50 -0800 Subject: [PATCH 0893/1168] [SPARK-11903] Remove --skip-java-test Per [pwendell's comments on SPARK-11903](https://issues.apache.org/jira/browse/SPARK-11903?focusedCommentId=15021511&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15021511) I'm removing this dead code. If we are concerned about preserving compatibility, I can instead leave the option in and add a warning. For example: ```sh echo "Warning: '--skip-java-test' is deprecated and has no effect." ;; ``` cc pwendell, srowen Author: Nicholas Chammas Closes #9924 from nchammas/make-distribution. --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index d7d27e253f721..7b417fe7cf619 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -69,9 +69,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; --with-tachyon) SPARK_TACHYON=true ;; From 167f4fad328d5f5a5bef9dfc16ce47a5d2ad2090 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 23 Nov 2015 22:26:08 -0800 Subject: [PATCH 0894/1168] [SPARK-10707][SQL] Fix nullability computation in union output Author: Mikhail Bautin Closes #9308 from mbautin/SPARK-10707. --- .../plans/logical/basicOperators.scala | 11 +++++-- .../spark/sql/execution/basicOperators.scala | 9 ++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 31 +++++++++++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0c444482c5e4c..737e62fd59214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - final override def output: Seq[Attribute] = left.output + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } final override lazy val resolved: Boolean = childrenResolved && @@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} case class Join( left: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e79092efdaa3e..d57b8e7a9ed61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -130,8 +130,13 @@ case class Sample( * Union two plans, without a distinct. This is UNION ALL in SQL. */ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 167aea87de077..bb82b562aaaa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. + checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + } From 75b6df523bfa944fe857e9ca7731371eae316e6b Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Tue, 24 Nov 2015 00:07:40 -0800 Subject: [PATCH 0895/1168] [SPARK-11897][SQL] Add @scala.annotations.varargs to sql functions Author: Xiu Guo Closes #9918 from xguo27/SPARK-11897. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b27b1340cce46..6137ce3a70fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -689,6 +689,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } @@ -871,6 +872,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } From 09b9d0b7a20f60dc145ece9f99bcf0b822787b5f Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Tue, 24 Nov 2015 09:03:32 +0000 Subject: [PATCH 0896/1168] [SPARK-11906][WEB UI] Speculation Tasks Cause ProgressBar UI Overflow When there are speculative tasks in the stage, running progress bar could overflow and goes hidden on a new line: ![image](https://cloud.githubusercontent.com/assets/4317392/11326841/5fd3482e-9142-11e5-8ca5-cb2f0c0c8964.png) 3 completed / 2 running (including 1 speculative) out of 4 total tasks This is a simple fix by capping the started tasks at `total - completed` tasks ![image](https://cloud.githubusercontent.com/assets/4317392/11326842/6bb67260-9142-11e5-90f0-37f9174878ec.png) I should note my preferred way to fix it is via css style ```css .progress { display: flex; } ``` which shifts the correction burden from driver to web browser. However I couldn't get selenium test to measure the position/dimension of the progress bar correctly to get this unit tested. It also has the side effect that the width will be calibrated so the running occupies 2 / 5 instead of 1 / 4. ![image](https://cloud.githubusercontent.com/assets/4317392/11326848/7b03e9f0-9142-11e5-89ad-bd99cb0647cf.png) All in all, since this cosmetic bug is minor enough, I suppose the original simple fix should be good enough. Author: Forest Fang Closes #9896 from saurfang/progressbar. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 4 +++- .../test/scala/org/apache/spark/ui/UIUtilsSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 25dcb604d9e5f..84a1116a5c498 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -319,7 +319,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
    diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 2b693c165180f..dd8d5ec27f87e 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -57,6 +57,16 @@ class UIUtilsSuite extends SparkFunSuite { ) } + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    , +
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) From dbe8b95528a3ca00988905a175e4883653e56214 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 24 Nov 2015 23:24:49 +0800 Subject: [PATCH 0897/1168] [SPARK-11043][SQL] BugFix:Set the operator log in the thrift server. `SessionManager` will set the `operationLog` if the configuration `hive.server2.logging.operation.enabled` is true in version of hive 1.2.1. But the spark did not adapt to this change, so no matter enabled the configuration or not, spark thrift server will always log the warn message. PS: if `hive.server2.logging.operation.enabled` is false, it should log the warn message (the same as hive thrift server). Author: huangzhaowei Closes #9056 from SaintBacchus/SPARK-11043. --- .../SparkExecuteStatementOperation.scala | 8 ++++---- .../thriftserver/SparkSQLSessionManager.scala | 5 +++++ .../thriftserver/HiveThriftServer2Suites.scala | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 82fef92dcb73b..e022ee86a763a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -134,12 +134,12 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { val sparkServiceUGI = Utils.getUGI() @@ -151,7 +151,7 @@ private[hive] class SparkExecuteStatementOperation( val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) @@ -188,7 +188,7 @@ private[hive] class SparkExecuteStatementOperation( } } - override def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index af4fcdf021bd4..de4e9c62b57a4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -41,6 +41,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, "initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1dd898aa38350..139d8e897ba1d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} +import scala.io.Source import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -507,6 +508,12 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(rs2.getInt(2) === 500) } } + + test("SPARK-11043 check operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -642,7 +649,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -679,6 +687,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -706,6 +715,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -782,6 +793,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null From 11001cc6345a7c8d4f6d732db788105e66766b5a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 24 Nov 2015 23:32:05 +0800 Subject: [PATCH 0898/1168] [SPARK-11592][SQL] flush spark-sql command line history to history file Currently, `spark-sql` would not flush command history when exiting. Author: Daoyuan Wang Closes #9563 from adrian-wang/jline. --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6419002a2aa89..4b928e600b355 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,6 +194,22 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + Runtime.getRuntime.addShutdownHook(new Thread(new Runnable() { + override def run() = { + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + })) + // TODO: missing /* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") From f2a2d251249765d252740ab25064037eb6e479a2 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 24 Nov 2015 09:20:09 -0800 Subject: [PATCH 0899/1168] =?UTF-8?q?[SPARK-11818][REPL]=20Fix=20ExecutorC?= =?UTF-8?q?lassLoader=20to=20lookup=20resources=20from=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …parent class loader Without patch, two additional tests of ExecutorClassLoaderSuite fails. - "resource from parent" - "resources from parent" Detailed explanation is here, https://issues.apache.org/jira/browse/SPARK-11818?focusedCommentId=15011202&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15011202 Author: Jungtaek Lim Closes #9812 from HeartSaVioR/SPARK-11818. --- .../spark/repl/ExecutorClassLoader.scala | 12 +++++++- .../spark/repl/ExecutorClassLoaderSuite.scala | 29 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a976e96809cb8..a8859fcd4584b 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -34,7 +34,9 @@ import org.apache.spark.util.ParentClassLoader /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. */ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { @@ -55,6 +57,14 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + override def findClass(name: String): Class[_] = { userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b1120..c1211f7596b9c 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -19,8 +19,13 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets +import java.util + +import com.google.common.io.Files import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps @@ -41,6 +46,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,6 +60,9 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } @@ -99,6 +108,26 @@ class ExecutorClassLoaderSuite } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's From 60c721a947fd866b69e2bca1ef6cf47b3847841f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 09:28:39 -0800 Subject: [PATCH 0900/1168] [SPARK-11942][SQL] fix encoder life cycle for CoGroup we should pass in resolved encodera to logical `CoGroup` and bind them in physical `CoGroup` Author: Wenchen Fan Closes #9928 from cloud-fan/cogroup. --- .../plans/logical/basicOperators.scala | 27 ++++++++++--------- .../org/apache/spark/sql/GroupedDataset.scala | 4 ++- .../spark/sql/execution/basicOperators.scala | 20 +++++++------- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++++++ 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 737e62fd59214..5665fd7e5f419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -553,19 +553,22 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { CoGroup( func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, leftGroup, rightGroup, left, @@ -577,12 +580,12 @@ object CoGroup { * A relation produced by applying `func` to each grouping key and associated values from left and * right children. */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 793a86b132907..a10a89342fb5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( f, + this.resolvedKEncoder, + this.resolvedVEncoder, + other.resolvedVEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d57b8e7a9ed61..a42aea0b96d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -375,12 +375,12 @@ case class MapGroups[K, T, U]( * iterators containing all elements in the group from left and right side. * The result of this function is encoded and flattened before being output. */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], @@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R]( left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + val boundRightEnc = rightEnc.bind(right.output) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dbdd7ba14a5b7..13eede1b17d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") From fab33b4756c7a5ad58e237af01b8b11425dba441 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 24 Nov 2015 09:52:53 -0800 Subject: [PATCH 0901/1168] [SPARK-11952][ML] Remove duplicate ml examples Remove duplicate ml examples (only for ml). mengxr Author: Yanbo Liang Closes #9933 from yanboliang/SPARK-11685. --- .../main/python/ml/gradient_boosted_trees.py | 82 ----------------- .../src/main/python/ml/logistic_regression.py | 66 -------------- .../main/python/ml/random_forest_example.py | 87 ------------------- 3 files changed, 235 deletions(-) delete mode 100644 examples/src/main/python/ml/gradient_boosted_trees.py delete mode 100644 examples/src/main/python/ml/logistic_regression.py delete mode 100644 examples/src/main/python/ml/random_forest_example.py diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index c3bf8aa2eb1e6..0000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. -Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 4cd027fdfbe8a..0000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index dc6a778670193..0000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. - - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() From a310b8cc3e5c7c2bc37c495a72031d97fa5a69a4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 24 Nov 2015 09:54:55 -0800 Subject: [PATCH 0902/1168] [SPARK-11521][ML][DOC] Document that Logistic, Linear Regression summaries ignore weight col Doc for 1.6 that the summaries mostly ignore the weight column. To be corrected for 1.7 CC: mengxr thunterdb Author: Joseph K. Bradley Closes #9927 from jkbradley/linregsummary-doc. --- .../ml/classification/LogisticRegression.scala | 18 ++++++++++++++++++ .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 418bbdc9a058f..d320d64dd90d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -755,23 +755,35 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the receiver operating characteristic (ROC) curve, * which is an Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** * Returns the precision-recall curve, which is an Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") @@ -781,6 +793,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") @@ -790,6 +805,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70ccec766c471..1db91666f21ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -540,6 +540,9 @@ class LinearRegressionSummary private[regression] ( * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -547,6 +550,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -554,6 +560,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -561,6 +570,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -568,6 +580,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 From c46a8d9bc2641fc50dbf4e6fb26e7245893f7878 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 24 Nov 2015 09:56:17 -0800 Subject: [PATCH 0903/1168] [SPARK-11847][ML] Model export/import for spark.ml: LDA Add read/write support to LDA, similar to ALS. save/load for ml.LocalLDAModel is done. For DistributedLDAModel, I'm not sure if we can invoke save on the mllib.DistributedLDAModel directly. I'll send update after some test. Author: Yuhao Yang Closes #9894 from hhbyyh/ldaMLsave. --- .../org/apache/spark/ml/clustering/LDA.scala | 110 +++++++++++++++++- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../apache/spark/ml/clustering/LDASuite.scala | 44 ++++++- 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 92e05815d6a3d..830510b1698d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] ( @Since("1.6.0") override val uid: String, @Since("1.6.0") val vocabSize: Int, @Since("1.6.0") @transient protected val sqlContext: SQLContext) - extends Model[LDAModel] with LDAParams with Logging { + extends Model[LDAModel] with LDAParams with Logging with MLWritable { // NOTE to developers: // This abstraction should contain all important functionality for basic LDA usage. @@ -486,6 +487,64 @@ class LocalLDAModel private[ml] ( @Since("1.6.0") override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) } @@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] ( */ @Since("1.6.0") lazy val logPrior: Double = oldDistributedModel.logPrior + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) } @@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] ( @Since("1.6.0") @Experimental class LDA @Since("1.6.0") ( - @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("lda")) @@ -695,7 +794,7 @@ class LDA @Since("1.6.0") ( } -private[clustering] object LDA { +private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { @@ -706,4 +805,7 @@ private[clustering] object LDA { (docId, features) } } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index cd520f09bd466..7384d065a2ea8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). */ @Since("1.3.0") -class LocalLDAModel private[clustering] ( +class LocalLDAModel private[spark] ( @Since("1.3.0") val topics: Matrix, @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double = 100) + override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { @Since("1.3.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b634d31cc34f0..97dbfd9a4314a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -39,10 +40,24 @@ object LDASuite { }.map(v => new TestRow(v)) sql.createDataFrame(rdd) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051 + ) } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { val k: Int = 5 val vocabSize: Int = 30 @@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val lp = model.logPrior assert(lp <= 0.0 && lp != Double.NegativeInfinity) } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } } From 835196ebf09b4427fbf12c8cf39846bced2fd73d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 11:09:01 -0800 Subject: [PATCH 0904/1168] [SPARK-11926][SQL] unify GetStructField and GetInternalRowField Author: Wenchen Fan Closes #9909 from cloud-fan/get-struct. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 8 +++---- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/complexTypeExtractors.scala | 18 ++++++++-------- .../expressions/namedExpressions.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 21 ------------------- .../expressions/ComplexTypeSuite.scala | 4 ++-- 9 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 476becec4dd52..d133ad3f0d89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or `BoundReference`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6485bdfb30234..1b2a8dc4c7f14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) Alias(extract, target.get + "." + f.name)() - }) } + case _ => { throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7bc9aed0b204e..0c10a56c555f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -111,7 +111,7 @@ object ExpressionEncoder { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index fa553e7c5324c..67518f52d4a58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,7 +220,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetInternalRowField(input, i, f.dataType))) + constructorFor(GetStructField(input, i))) } CreateExternalRow(convertedFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 540ed3500616a..169435a10ea2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f871b737fff3a..10ce10aaf6da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -97,18 +97,18 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. - * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = child.dataType match { - case s: StructType => s(ordinal).dataType - // This is a hack to avoid breaking existing code until we remove the need for the struct field - case _ => field.dataType - } + private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + + override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, field.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00b7970bd16c6..26b6aca79971e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -273,7 +273,8 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4a1f419f0ad8d..62d09f0f55105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { } } -case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) - extends UnaryExpression { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ - }) - } -} - /** * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423f..62fd47234b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } From 8c41bbec935fc9b0023937c1f0246bdf4bcfc6d6 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 24 Nov 2015 12:22:33 -0800 Subject: [PATCH 0905/1168] [SPARK-11872] Prevent the call to SparkContext#stop() in the listener bus's thread This is continuation of SPARK-11761 Andrew suggested adding this protection. See tail of https://github.com/apache/spark/pull/9741 Author: tedyu Closes #9852 from tedyu/master. --- .../scala/org/apache/spark/SparkContext.scala | 4 +++ .../spark/scheduler/SparkListenerSuite.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b153a7b08e590..e19ba113702c6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1694,6 +1694,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop SparkContext within listener thread of" + + " AsynchronousListenerBus") + } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. if (!stopped.compareAndSet(false, true)) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 84e545851f49e..f20d5be7c0ee0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark.SparkException import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -36,6 +37,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -443,6 +459,21 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. + */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 From 5db64cde0e53a09aa8bf500f3a2743c2c6c046e4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 12:54:37 -0800 Subject: [PATCH 0906/1168] [SPARK-11946][SQL] Audit pivot API for 1.6. Currently pivot's signature looks like ```scala scala.annotation.varargs def pivot(pivotColumn: Column, values: Column*): GroupedData scala.annotation.varargs def pivot(pivotColumn: String, values: Any*): GroupedData ``` I think we can remove the one that takes "Column" types, since callers should always be passing in literals. It'd also be more clear if the values are not varargs, but rather Seq or java.util.List. I also made similar changes for Python. Author: Reynold Xin Closes #9929 from rxin/SPARK-11946. --- .../apache/spark/scheduler/DAGScheduler.scala | 1 - python/pyspark/sql/group.py | 12 +- .../sql/catalyst/expressions/literals.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 154 ++++++++++-------- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ++ .../spark/sql/DataFramePivotSuite.scala | 21 +-- .../apache/spark/sql/test/SQLTestData.scala | 1 + 7 files changed, 125 insertions(+), 81 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ae725b467d8c4..77a184dfe4bee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1574,7 +1574,6 @@ class DAGScheduler( } def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 227f40bc3cf53..d8ed7eb2dda64 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -168,20 +168,24 @@ def sum(self, *cols): """ @since(1.6) - def pivot(self, pivot_col, *values): + def pivot(self, pivot_col, values=None): """Pivots a column of the current DataFrame and preform the specified aggregation. :param pivot_col: Column to pivot :param values: Optional list of values of pivotColumn that will be translated to columns in the output data frame. If values are not provided the method with do an immediate call to .distinct() on the pivot column. - >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect() + + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ - jgd = self._jdf.pivot(_to_java_column(pivot_col), - _to_seq(self.sql_ctx._sc, values, _create_column_from_literal)) + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e34fd49be8389..68ec688c99f93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -44,6 +44,7 @@ object Literal { case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 63dd7fbcbe9e4..ee7150cbbfbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.NumericType /** @@ -282,74 +282,96 @@ class GroupedData protected[sql]( } /** - * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified - * aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) - * // Or without specifying column values - * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { - case _: GroupedData.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case GroupedData.GroupByType => - val pivotValues = if (values.nonEmpty) { - values.map { - case Column(literal: Literal) => literal - case other => - throw new UnsupportedOperationException( - s"The values of a pivot must be literals, found $other") - } - } else { - // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) - // Get the distinct values of the column and sort them so its consistent - val values = df.select(pivotColumn) - .distinct() - .sort(pivotColumn) - .map(_.get(0)) - .take(maxValues + 1) - .map(Literal(_)).toSeq - if (values.length > maxValues) { - throw new RuntimeException( - s"The pivot column $pivotColumn has more than $maxValues distinct values, " + - "this could indicate an error. " + - "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + - s"to at least the number of distinct values of the pivot column.") - } - values - } - new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @since 1.6.0 + */ + def pivot(pivotColumn: String): GroupedData = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") - * // Or without specifying column values - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: String, values: Any*): GroupedData = { - val resolvedPivotColumn = Column(df.resolve(pivotColumn)) - pivot(resolvedPivotColumn, values.map(functions.lit): _*) + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + groupType match { + case GroupedData.GroupByType => + new GroupedData( + df, + groupingExprs, + GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + pivot(pivotColumn, values.asScala) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 567bdddece80e..a12fed3c0c6af 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -282,4 +282,20 @@ public void testSampleBy() { Assert.assertEquals(1, actual[1].getLong(0)); Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } + + @Test + public void pivot() { + DataFrame df = context.table("courseSales"); + Row[] actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collect(); + + Assert.assertEquals(2012, actual[0].getInt(0)); + Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + + Assert.assertEquals(2013, actual[1].getInt(0)); + Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 0c23d142670c1..fc53aba68ebb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with literals") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) @@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with literals") { checkAnswer( - courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot courses with literals and multiple aggregations") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil @@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with string values (cast)") { checkAnswer( - courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot year with int values") { checkAnswer( - courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } @@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with no values") { // Note Java comes before dotNet in sorted order checkAnswer( - courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil ) } test("pivot year with no values") { checkAnswer( - courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot max values inforced") { + test("pivot max values enforced") { sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) - intercept[RuntimeException]( - courseSales.groupBy($"year").pivot($"course") + intercept[AnalysisException]( + courseSales.groupBy("year").pivot("course") ) sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index abad0d7eaaedf..83c63e04f344a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self => person salary complexData + courseSales } } From 301669531cbfdc7321cff6a654a8449d53497b94 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Nov 2015 15:08:02 -0600 Subject: [PATCH 0907/1168] [SPARK-11929][CORE] Make the repl log4j configuration override the root logger. In the default Spark distribution, there are currently two separate log4j config files, with different default values for the root logger, so that when running the shell you have a different default log level. This makes the shell more usable, since the logs don't overwhelm the output. But if you install a custom log4j.properties, you lose that, because then it's going to be used no matter whether you're running a regular app or the shell. With this change, the overriding of the log level is done differently; the log level repl's main class (org.apache.spark.repl.Main) is used to define the root logger's level when running the shell, defaulting to WARN if it's not set explicitly. On a somewhat related change, the shell output about the "sc" variable was changed a bit to contain a little more useful information about the application, since when the root logger's log level is WARN, that information is never shown to the user. Author: Marcelo Vanzin Closes #9816 from vanzin/shell-logging. --- conf/log4j.properties.template | 5 +++ .../spark/log4j-defaults-repl.properties | 33 -------------- .../apache/spark/log4j-defaults.properties | 5 +++ .../main/scala/org/apache/spark/Logging.scala | 45 ++++++++++--------- .../apache/spark/repl/SparkILoopInit.scala | 21 ++++----- .../org/apache/spark/repl/SparkILoop.scala | 25 ++++++----- 6 files changed, 57 insertions(+), 77 deletions(-) delete mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f3046be54d7c6..9809b0c828487 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties deleted file mode 100644 index c85abc35b93bf..0000000000000 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO - -# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support -log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL -log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d44cc85dcbd82..0750488e4adf9 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 69f6e06ee0057..e35e158c7e8a6 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder @@ -119,30 +119,31 @@ trait Logging { val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println if (!log4j12Initialized) { - // scalastyle:off println - if (Utils.isInInterpreter) { - val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" - Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") - System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") - case None => - System.err.println(s"Spark was unable to load $replDefaultLogProps") - } - } else { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") - } + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") } - // scalastyle:on println } + + if (Utils.isInInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println } Logging.initialized = true diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed6..99e1e1df33fd8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 33d262558b1fc..e91139fb29f69 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -37,18 +37,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -85,7 +86,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +99,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. */ From 0c314f5e377d97ac3115e57dc76a6a9d20d851c0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Nov 2015 14:33:28 -0800 Subject: [PATCH 0908/1168] [SPARK-11805] free the array in UnsafeExternalSorter during spilling After calling spill() on SortedIterator, the array inside InMemorySorter is not needed, it should be freed during spilling, this could help to join multiple tables with limited memory. Author: Davies Liu Closes #9793 from davies/free_array. --- .../unsafe/sort/UnsafeExternalSorter.java | 10 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 31 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9a7b2ad06cab6..2e40312674737 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -468,6 +468,12 @@ public long spill() throws IOException { } allocatedPages.clear(); } + + // in-memory sorter will not be used after spilling + assert(inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + inMemSorter.free(); + inMemSorter = null; return released; } } @@ -489,10 +495,6 @@ public void loadNext() throws IOException { } upstream = nextUpstream; nextUpstream = null; - - assert(inMemSorter != null); - inMemSorter.free(); - inMemSorter = null; } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f46..dce1f15a2963c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -108,6 +108,7 @@ public UnsafeInMemorySorter( */ public void free() { consumer.freeArray(array); + array = null; } public void reset() { @@ -160,28 +161,22 @@ public void insertRecord(long recordPointer, long keyPrefix) { pos++; } - public static final class SortedIterator extends UnsafeSorterIterator { + public final class SortedIterator extends UnsafeSorterIterator { - private final TaskMemoryManager memoryManager; - private final int sortBufferInsertPosition; - private final LongArray sortBuffer; - private int position = 0; + private final int numRecords; + private int position; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator( - TaskMemoryManager memoryManager, - int sortBufferInsertPosition, - LongArray sortBuffer) { - this.memoryManager = memoryManager; - this.sortBufferInsertPosition = sortBufferInsertPosition; - this.sortBuffer = sortBuffer; + private SortedIterator(int numRecords) { + this.numRecords = numRecords; + this.position = 0; } public SortedIterator clone () { - SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -192,21 +187,21 @@ public SortedIterator clone () { @Override public boolean hasNext() { - return position < sortBufferInsertPosition; + return position / 2 < numRecords; } public int numRecordsLeft() { - return (sortBufferInsertPosition - position) / 2; + return numRecords - position / 2; } @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer.get(position); + final long recordPointer = array.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer.get(position + 1); + keyPrefix = array.get(position + 1); position += 2; } @@ -229,6 +224,6 @@ public void loadNext() { */ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(memoryManager, pos, array); + return new SortedIterator(pos / 2); } } From 372244d64a8647f482c08352ff70eec8338c3c50 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 14:51:01 -0800 Subject: [PATCH 0909/1168] Added a line of comment to explain why the extra sort exists in pivot. --- sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee7150cbbfbca..abd531c4ba541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -304,7 +304,7 @@ class GroupedData protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() - .sort(pivotColumn) + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order .map(_.get(0)) .take(maxValues + 1) .toSeq From d625b5493e1b5d020d08b227d01c9e961187ee0f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Nov 2015 14:59:14 -0800 Subject: [PATCH 0910/1168] [SPARK-11783][SQL] Fixes execution Hive client when using remote Hive metastore When using remote Hive metastore, `hive.metastore.uris` is set to the metastore URI. However, it overrides `javax.jdo.option.ConnectionURL` unexpectedly, thus the execution Hive client connects to the actual remote Hive metastore instead of the Derby metastore created in the temporary directory. Cleaning this configuration for the execution Hive client fixes this issue. Author: Cheng Lian Closes #9895 from liancheng/spark-11783.clean-remote-metastore-config. --- .../org/apache/spark/sql/hive/HiveContext.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0bb5af7d5c85..8a4264194ae8d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -736,6 +736,21 @@ private[hive] object HiveContext { s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by remote metastore server that talks to the database). + // Because execution Hive should always connects to a embedded derby metastore. + // We have to remove the value of hive.metastore.uris. So, the execution Hive client connects + // to the actual embedded derby metastore instead of the remote metastore. + // You can search HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo). + // Then, you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + propMap.toMap } From b9dabec8f923eda9ea22b2ad0dc3b2971ae4c553 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 24 Nov 2015 15:54:10 -0800 Subject: [PATCH 0911/1168] [SPARK-11914][SQL] Support coalesce and repartition in Dataset APIs This PR is to provide two common `coalesce` and `repartition` in Dataset APIs. After reading the comments of SPARK-9999, I am unclear about the plan for supporting re-partitioning in Dataset APIs. Currently, both RDD APIs and Dataframe APIs provide users such a flexibility to control the number of partitions. In most traditional RDBMS, they expose the number of partitions, the partitioning columns, the table partitioning methods to DBAs for performance tuning and storage planning. Normally, these parameters could largely affect the query performance. Since the actual performance depends on the workload types, I think it is almost impossible to automate the discovery of the best partitioning strategy for all the scenarios. I am wondering if Dataset APIs are planning to hide these APIs from users? Feel free to reject my PR if it does not match the plan. Thank you for your answers. marmbrus rxin cloud-fan Author: gatorsmile Closes #9899 from gatorsmile/coalesce. --- .../scala/org/apache/spark/sql/Dataset.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 07647508421a4..17e2611790d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -152,6 +152,25 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * @since 1.6.0 + */ + def repartition(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = true, _) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @since 1.6.0 + */ + def coalesce(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = false, _) + } + /* *********************** * * Functional Operations * * *********************** */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 13eede1b17d8b..c253fdbb8c99e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -52,6 +52,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.takeAsList(1).get(0) == item) } + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkAnswer( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkAnswer( + ds.coalesce(1), + data: _*) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( From 49003f487428a79e03690ed41d49eaac8b504dfb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:16:07 -0800 Subject: [PATCH 0912/1168] [SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader This patch makes it consistent to use varargs in all DataFrameReader methods, including Parquet, JSON, text, and the generic load function. Also added a few more API tests for the Java API. Author: Reynold Xin Closes #9945 from rxin/SPARK-11967. --- python/pyspark/sql/readwriter.py | 19 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 36 +++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 23 ++++++++++++ sql/core/src/test/resources/text-suite2.txt | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 5 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/text-suite2.txt diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e8f0d7ec77035..2e75f0c8a1827 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. + :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) - return self._df(self._jreader.load(jpaths)) + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load(path)) else: @@ -175,6 +174,8 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): return self._df(self._jreader.json(path._jrdd)) else: @@ -205,16 +206,20 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): + def text(self, paths): """Loads a text file and returns a [[DataFrame]] with a single string column named "text". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dcb3737b70fbf..3ed1e55adec6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,17 +24,17 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def load(path: String): DataFrame = { option("path", path).load() } @@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { + @scala.annotation.varargs + def load(paths: String*): DataFrame = { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } @@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • * - * @param path input path * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • + *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • + *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + *
  • + *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)
  • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * sqlContext.read().text("/path/to/spark/README.md") * }}} * - * @param path input path + * @param paths input path * @since 1.6.0 */ - def text(path: String): DataFrame = format("text").load(path) + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a12fed3c0c6af..8e0b2dbca4a98 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -298,4 +298,27 @@ public void pivot() { Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + + public void testGenericLoad() { + DataFrame df1 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + DataFrame df1 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } } diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 0000000000000..f9d498c80493c --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd6d06512ff60..76e9648aa7533 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) checkAnswer(sqlContext.read.format("json").load(dir1), From 5382d27f228f93bd67d605fb5c89feed20661f17 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:58:55 -0800 Subject: [PATCH 0913/1168] [SPARK-11947][SQL] Mark deprecated methods with "This will be removed in Spark 2.0." Also fixed some documentation as I saw them. Author: Reynold Xin Closes #9930 from rxin/SPARK-11947. --- project/MimaExcludes.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 20 +++-- .../org/apache/spark/sql/DataFrame.scala | 72 +++++++++------ .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 88 ++++++++++--------- .../org/apache/spark/sql/SQLImplicits.scala | 25 +++++- .../sql/{ => execution}/SparkSQLParser.scala | 15 ++-- .../org/apache/spark/sql/functions.scala | 52 ++++++----- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 9 files changed, 172 insertions(+), 110 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/SparkSQLParser.scala (89%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bb45d1bb12146..54a9ad956d119 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -108,7 +108,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 30c554a85e693..b3cd9e1eff142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -42,7 +42,8 @@ private[sql] object Column { /** * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. - * @since 1.6.0 + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. @@ -51,7 +52,8 @@ private[sql] object Column { */ class TypedColumn[-T, U]( expr: Expression, - private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -61,12 +63,11 @@ class TypedColumn[-T, U]( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U] (expr transform { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(boundEncoder), - children = schema) - }, encoder) + new TypedColumn[T, U]( + expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy(aEncoder = Some(boundEncoder), children = schema) + }, + encoder) } } @@ -691,8 +692,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 + * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. */ - @deprecated("use isin", "1.5.0") + @deprecated("use isin. This will be removed in Spark 2.0.", "1.5.0") @scala.annotation.varargs def in(list: Any*): Column = isin(list : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5586fc994b98a..5eca1db9525ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1713,9 +1713,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `toDF()`. + * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. */ - @deprecated("use toDF", "1.3.0") + @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") def toSchemaRDD: DataFrame = this /** @@ -1725,9 +1725,9 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. + * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write w.jdbc(url, table, new Properties) @@ -1744,9 +1744,9 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) @@ -1757,9 +1757,9 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. + * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.parquet(path)", "1.4.0") + @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") def saveAsParquetFile(path: String): Unit = { write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } @@ -1782,8 +1782,9 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String): Unit = { write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } @@ -1805,8 +1806,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(mode).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { write.mode(mode).saveAsTable(tableName) } @@ -1829,8 +1832,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { write.format(source).saveAsTable(tableName) } @@ -1853,8 +1858,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).saveAsTable(tableName) } @@ -1877,9 +1884,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1907,9 +1915,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1923,9 +1932,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. + * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use write.save(path)", "1.4.0") + @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String): Unit = { write.save(path) } @@ -1935,8 +1944,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default. * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") + @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, mode: SaveMode): Unit = { write.mode(mode).save(path) } @@ -1946,8 +1956,9 @@ class DataFrame private[sql]( * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).save(path)", "1.4.0") + @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String): Unit = { write.format(source).save(path) } @@ -1957,8 +1968,10 @@ class DataFrame private[sql]( * [[SaveMode]] specified by mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).save(path). " + + "This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).save(path) } @@ -1969,8 +1982,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1985,8 +2000,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1994,14 +2011,15 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } @@ -2012,8 +2030,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String): Unit = { write.mode(SaveMode.Append).insertInto(tableName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 17e2611790d5a..dd84b8bc11e2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** + * :: Experimental :: * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel * using functional or relational operations. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39471d2fb79a7..46bf544fd885f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -942,33 +942,33 @@ class SQLContext private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } @@ -978,9 +978,9 @@ class SQLContext private[sql]( * [[DataFrame]] if no paths are passed in. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. + * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.parquet()", "1.4.0") + @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { @@ -995,9 +995,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String): DataFrame = { read.json(path) } @@ -1007,18 +1007,18 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { read.schema(schema).json(path) } /** * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(path) } @@ -1029,9 +1029,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** @@ -1040,9 +1040,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** @@ -1050,9 +1050,9 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1062,9 +1062,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1074,9 +1074,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1086,9 +1086,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1098,9 +1098,9 @@ class SQLContext private[sql]( * using the default data source configured by spark.sql.sources.default. * * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. + * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use read.load(path)", "1.4.0") + @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String): DataFrame = { read.load(path) } @@ -1110,8 +1110,9 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).load(path)", "1.4.0") + @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String, source: String): DataFrame = { read.format(source).load(path) } @@ -1122,8 +1123,10 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1135,7 +1138,8 @@ class SQLContext private[sql]( * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1148,7 +1152,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() @@ -1162,7 +1167,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() } @@ -1172,9 +1178,9 @@ class SQLContext private[sql]( * url named table. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String): DataFrame = { read.jdbc(url, table, new Properties) } @@ -1190,9 +1196,9 @@ class SQLContext private[sql]( * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc( url: String, table: String, @@ -1210,9 +1216,9 @@ class SQLContext private[sql]( * of the [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { read.jdbc(url, table, theParts, new Properties) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 25ffdcde17717..6735d02954b8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -30,19 +30,38 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 */ abstract class SQLImplicits { + protected def _sqlContext: SQLContext + /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** @@ -84,9 +103,9 @@ abstract class SQLImplicits { DataFrameHolder(_sqlContext.createDataFrame(data)) } - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. + // Do NOT add more implicit conversions for primitive types. + // They are likely to break source compatibility by making existing implicit conversions + // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. /** * Creates a single column DataFrame from an RDD[Int]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index ea8fce6ca9cf2..b3e8d0d84937e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType - /** * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * * @param fallback A function that parses an input string to a logical plan */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -100,14 +99,14 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case _ ~ dbName => ShowTablesCommand(dbName) } | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) + case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) + case None => logical.ShowFunctions(None, None) } ) private lazy val desc: Parser[LogicalPlan] = DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) + case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6137ce3a70fdb..77dd5bc72508b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql - - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -39,11 +37,11 @@ import org.apache.spark.util.Utils * "bridge" methods due to the use of covariant return types. * * {{{ - * In LegacyFunctions: - * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * // In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); * - * In functions: - * public static org.apache.spark.sql.TypedColumn avg(...); + * // In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); * }}} * * This allows us to use the same functions both in typed [[Dataset]] operations and untyped @@ -2528,8 +2526,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } @@ -2541,8 +2540,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } @@ -2554,8 +2554,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } @@ -2567,8 +2568,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } @@ -2580,8 +2582,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } @@ -2593,8 +2596,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } @@ -2606,8 +2610,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } @@ -2619,8 +2624,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } @@ -2632,8 +2638,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } @@ -2644,9 +2651,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } @@ -2657,9 +2665,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2700,9 +2709,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. + * This will be removed in Spark 2.0. */ - @deprecated("Use callUDF", "1.5.0") + @deprecated("Use callUDF. This will be removed in Spark 2.0.", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 81af684ba0bf1..b554d135e4b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -80,7 +80,11 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat private var partitionedDF: DataFrame = _ - private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + private val partitionedDataSchema: StructType = + new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("c", StringType) protected override def beforeAll(): Unit = { this.tempPath = Utils.createTempDir() From 75274221c2c56aca1b8a2fcd2d391777c4eb938d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 24 Nov 2015 20:58:47 -0800 Subject: [PATCH 0914/1168] [STREAMING][FLAKY-TEST] Catch execution context race condition in `FileBasedWriteAheadLog.close()` There is a race condition in `FileBasedWriteAheadLog.close()`, where if delete's of old log files are in progress, the write ahead log may close, and result in a `RejectedExecutionException`. This is okay, and should be handled gracefully. Example test failures: https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.6-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=spark-test/95/testReport/junit/org.apache.spark.streaming.util/BatchedWriteAheadLogWithCloseFileAfterWriteSuite/BatchedWriteAheadLog___clean_old_logs/ The reason the test fails is in `afterEach`, `writeAheadLog.close` is called, and there may still be async deletes in flight. tdas zsxwing Author: Burak Yavuz Closes #9953 from brkyvz/flaky-ss. --- .../streaming/util/FileBasedWriteAheadLog.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 72705f1a9c010..f5165f7c39122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} import scala.collection.JavaConverters._ @@ -176,10 +176,16 @@ private[streaming] class FileBasedWriteAheadLog( } oldLogFiles.foreach { logInfo => if (!executionContext.isShutdown) { - val f = Future { deleteFile(logInfo) }(executionContext) - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) } } } From ecab96d78a2ef28e9f440c5e8fa7b51b88fe69c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 21:30:53 -0800 Subject: [PATCH 0915/1168] [SPARK-10621][SQL] Consistent naming for functions in SQL, Python, Scala Author: Reynold Xin Closes #9948 from rxin/SPARK-10621. --- python/pyspark/sql/functions.py | 111 +++++++++++++--- .../org/apache/spark/sql/functions.scala | 124 ++++++++++++++---- 2 files changed, 196 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a1ca723bbd7ab..e3786e0fa5fb2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -150,18 +150,18 @@ def _(): _window_functions = { 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", + """.. note:: Deprecated in 1.6, use row_number instead.""", + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", 'denseRank': + """.. note:: Deprecated in 1.6, use dense_rank instead.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -172,14 +172,14 @@ def _(): This is equivalent to the RANK function in SQL.""", 'cumeDist': + """.. note:: Deprecated in 1.6, use cume_dist instead.""", + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", + i.e. the fraction of rows that are below the current row.""", 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + """.. note:: Deprecated in 1.6, use percent_rank instead.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -189,7 +189,7 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc @@ -288,6 +288,38 @@ def countDistinct(col, *cols): @since(1.4) def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.input_file_name()) + + +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnan(_to_java_column(col))) + + +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnull(_to_java_column(col))) + + +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -300,11 +332,21 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -382,15 +424,23 @@ def shiftRightUnsigned(col, numBits): @since(1.4) def sparkPartitionId(): + """ + .. note:: Deprecated in 1.6, use spark_partition_id instead. + """ + return spark_partition_id() + + +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. - >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) @since(1.5) @@ -1410,6 +1460,33 @@ def explode(col): return Column(jc) +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@since(1.6) +def json_tuple(col, fields): + """Creates a new row for a json column according to the given field names. + + :param col: string column in json format + :param fields: list of fields to extract + + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77dd5bc72508b..276c5dfc8b062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -472,6 +472,13 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def cumeDist(): Column = cume_dist() + /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -481,13 +488,17 @@ object functions extends LegacyFunctions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") + def denseRank(): Column = dense_rank() /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -497,12 +508,10 @@ object functions extends LegacyFunctions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -620,6 +629,13 @@ object functions extends LegacyFunctions { */ def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") + def percentRank(): Column = percent_rank() + /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * @@ -631,9 +647,9 @@ object functions extends LegacyFunctions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -650,15 +666,20 @@ object functions extends LegacyFunctions { */ def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. + */ + @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") + def rowNumber(): Column = row_number() + /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -720,20 +741,43 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. + */ + @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") + def inputFileName(): Column = input_file_name() + /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 */ - def inputFileName(): Column = withExpr { InputFileName() } + def input_file_name(): Column = withExpr { InputFileName() } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. + */ + @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") + def isNaN(e: Column): Column = isnan(e) /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 + */ + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 */ - def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -750,7 +794,24 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -825,15 +886,23 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) + /** + * @group normal_funcs + * @since 1.4.0 + * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def sparkPartitionId(): Column = spark_partition_id() + /** * Partition ID of the Spark task. * * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = withExpr { SparkPartitionID() } + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -2305,6 +2374,17 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + /** * Creates a new row for a json column according to the given field names. * @@ -2313,7 +2393,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def json_tuple(json: Column, fields: String*): Column = withExpr { - require(fields.length > 0, "at least 1 field name should be given.") + require(fields.nonEmpty, "at least 1 field name should be given.") JsonTuple(json.expr +: fields.map(Literal.apply)) } From 8364c33978af3a2502ea6442f6ab3b14e3d6d535 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Nov 2015 23:13:01 -0800 Subject: [PATCH 0916/1168] [SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file This solves the following exception caused when empty state RDD is checkpointed and recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized as the initialCapacity is set to zero. ``` Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity at scala.Predef$.require(Predef.scala:233) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:96) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:86) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370) at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76) at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181) at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:744) ``` Author: Tathagata Das Closes #9958 from tdas/SPARK-11979. --- .../spark/streaming/util/StateMap.scala | 19 +++++++----- .../spark/streaming/StateMapSuite.scala | 30 ++++++++++++------- .../streaming/rdd/TrackStateRDDSuite.scala | 10 +++++++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 34287c3e00908..3f139ad138c88 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -59,7 +59,7 @@ private[streaming] object StateMap { def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) } } @@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaChainThreshold = deltaChainThreshold) def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { @@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the parent map. Keep reading records, until the limiter is reached // First read the approximate number of records to expect and allocate properly size // OpenHashMap - val parentSessionStoreSizeHint = inputStream.readInt() + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) // Read the records until the limit marking object has been reached var parentSessionLoopDone = false @@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 48d3b41b66cbf..c4a01eaea739e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite { test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + map1.put(1, 100, 1) map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + map2.put(3, 300, 3) map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) - - // Do not test compaction - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - - val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( - Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite { assert(map.deltaChainLength > deltaChainThreshold) assert(map.shouldCompact === true) - val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") } test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { @@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } + private def testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + // Assert whether all the data and operations on a state map matches that of a reference state map private def assertMap( mapToTest: StateMap[Int, Int], diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0feb3af1abb0f..3b2d43f2ce581 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) } + test("checkpointing empty state RDD") { + val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From 324d3bf26e404b6cd9e5afdaf779e980eaa511d7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 01:02:36 -0800 Subject: [PATCH 0917/1168] [SPARK-11970][SQL] Adding JoinType into JoinWith and support Sample in Dataset API Except inner join, maybe the other join types are also useful when users are using the joinWith function. Thus, added the joinType into the existing joinWith call in Dataset APIs. Also providing another joinWith interface for the cartesian-join-like functionality. Please provide your opinions. marmbrus rxin cloud-fan Thank you! Author: gatorsmile Closes #9921 from gatorsmile/joinWith. --- .../scala/org/apache/spark/sql/Dataset.scala | 45 +++++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 36 ++++++++++++--- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dd84b8bc11e2b..97eb5b969280d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -83,7 +83,6 @@ class Dataset[T] private[sql]( /** * Returns the schema of the encoded form of the objects in this [[Dataset]]. - * * @since 1.6.0 */ def schema: StructType = resolvedTEncoder.schema @@ -185,7 +184,6 @@ class Dataset[T] private[sql]( * .transform(featurize) * .transform(...) * }}} - * * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) @@ -453,6 +451,21 @@ class Dataset[T] private[sql]( c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + /** + * Returns a new [[Dataset]] by sampling a fraction of records. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + + /** + * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + /* **************** * * Set operations * * **************** */ @@ -511,13 +524,17 @@ class Dataset[T] private[sql]( * types as well as working with relational data where either side of the join has column * names in common. * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val left = this.logicalPlan val right = other.logicalPlan - val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) val leftOutput = joined.analyzed.output.take(left.output.length) val rightOutput = joined.analyzed.output.takeRight(right.output.length) @@ -540,6 +557,18 @@ class Dataset[T] private[sql]( } } + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. + * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -584,7 +613,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -594,7 +622,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c253fdbb8c99e..7d539180ded9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(1, 2).toDS().as("b") checkAnswer( - ds1.joinWith(ds2, $"a.value" === $"b.value"), + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } - test("joinWith, expression condition") { - val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() - val ds2 = Seq(("a", 1), ("b", 2)).toDS() + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() checkAnswer( - ds1.joinWith(ds2, $"_1" === $"a"), - (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) } test("joinWith tuple with primitive, expression") { @@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) - } test("groupBy function, keys") { @@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a", 2 -> "bc", 3 -> "d") } + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { case class ClassData(a: String, b: Int) +case class ClassNullableData(a: String, b: Integer) /** * A class used to test serialization using encoders. This class throws exceptions when using From 09c6a68edaf59374aa1913679137023bf5453e8a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 01:03:18 -0800 Subject: [PATCH 0918/1168] [SPARK-11981][SQL] Move implementations of methods back to DataFrame from Queryable Also added show methods to Dataset. Author: Reynold Xin Closes #9964 from rxin/SPARK-11981. --- .../org/apache/spark/sql/DataFrame.scala | 35 ++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 77 ++++++++++++++++++- .../spark/sql/execution/Queryable.scala | 32 ++------ 3 files changed, 111 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5eca1db9525ec..d8319b9a97fcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -112,8 +112,8 @@ private[sql] object DataFrame { */ @Experimental class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) + @transient override val sqlContext: SQLContext, + @DeveloperApi @transient override val queryExecution: QueryExecution) extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure @@ -282,6 +282,35 @@ class DataFrame private[sql]( */ def schema: StructType = queryExecution.analyzed.schema + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + override def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + * @since 1.3.0 + */ + override def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + override def explain(): Unit = explain(extended = false) + /** * Returns all column names and their data types as an array. * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 97eb5b969280d..da4600133290f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -61,8 +61,8 @@ import org.apache.spark.util.Utils */ @Experimental class Dataset[T] private[sql]( - @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution, + @transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution, tEncoder: Encoder[T]) extends Queryable with Serializable { /** @@ -85,7 +85,25 @@ class Dataset[T] private[sql]( * Returns the schema of the encoded form of the objects in this [[Dataset]]. * @since 1.6.0 */ - def schema: StructType = resolvedTEncoder.schema + override def schema: StructType = resolvedTEncoder.schema + + /** + * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * @since 1.6.0 + */ + override def printSchema(): Unit = toDF().printSchema() + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(extended: Boolean): Unit = toDF().explain(extended) + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(): Unit = toDF().explain() /* ************* * * Conversions * @@ -152,6 +170,59 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 321e2c783537f..f2f5997d1b7c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType -import scala.util.control.NonFatal - /** A trait that holds shared code between DataFrames and Datasets. */ private[sql] trait Queryable { def schema: StructType @@ -37,31 +37,9 @@ private[sql] trait Queryable { } } - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println + def printSchema(): Unit - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } + def explain(extended: Boolean): Unit - /** - * Only prints the physical plan to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) + def explain(): Unit } From d4fc583a0016dab3e2014a6562a09a00e08d5854 Mon Sep 17 00:00:00 2001 From: Ashwin Swaroop Date: Wed, 25 Nov 2015 13:41:14 +0000 Subject: [PATCH 0919/1168] [SPARK-11686][CORE] Issue WARN when dynamic allocation is disabled due to spark.dynamicAllocation.enabled and spark.executor.instances both set Changed the log type to a 'warning' instead of 'info' as required. Author: Ashwin Swaroop Closes #9926 from ashwinswaroop/master. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e19ba113702c6..2c10779f2b893 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -556,7 +556,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") } _executorAllocationManager = From daf54f0a103488bc5ea17545e26561da66a97a62 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 13:49:58 +0000 Subject: [PATCH 0920/1168] =?UTF-8?q?[SPARK-11860][PYSAPRK][DOCUMENTATION]?= =?UTF-8?q?=20Invalid=20argument=20specification=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …for registerFunction [Python] Straightforward change on the python doc Author: Jeff Zhang Closes #9901 from zjffdu/SPARK-11860. --- python/pyspark/sql/context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5a85ac31025e8..a49c1b58d0180 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -195,14 +195,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) From 0132b472fbf34ce4b220317d6874f1a774a3d70e Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 25 Nov 2015 09:34:34 -0600 Subject: [PATCH 0921/1168] [SPARK-10666][SPARK-6880][CORE] Use properties from ActiveJob associated with a Stage This issue was addressed in https://github.com/apache/spark/pull/5494, but the fix in that PR, while safe in the sense that it will prevent the SparkContext from shutting down, misses the actual bug. The intent of `submitMissingTasks` should be understood as "submit the Tasks that are missing for the Stage, and run them as part of the ActiveJob identified by jobId". Because of a long-standing bug, the `jobId` parameter was never being used. Instead, we were trying to use the jobId with which the Stage was created -- which may no longer exist as an ActiveJob, hence the crash reported in SPARK-6880. The correct fix is to use the ActiveJob specified by the supplied jobId parameter, which is guaranteed to exist at the call sites of submitMissingTasks. This fix should be applied to all maintenance branches, since it has existed since 1.0. kayousterhout pankajarora12 Author: Mark Hamstra Author: Imran Rashid Closes #6291 from markhamstra/SPARK-6880. --- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 107 +++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77a184dfe4bee..e01a9609b9a0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,9 @@ class DAGScheduler( stage.resetInternalAccumulators() } - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -1047,7 +1049,7 @@ class DAGScheduler( stage.pendingPartitions ++= tasks.map(_.partitionId) logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4d6b25455226f..653d41fc053c9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.Properties + import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal @@ -262,9 +264,10 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - listener: JobListener = jobListener): Int = { + listener: JobListener = jobListener, + properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) jobId } @@ -1322,6 +1325,106 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) + val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. + val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) From c319c4abfceae4c4f3d0ec9a898603195a8aa417 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 09:47:20 -0800 Subject: [PATCH 0922/1168] [SPARK-11956][CORE] Fix a few bugs in network lib-based file transfer. - NettyRpcEnv::openStream() now correctly propagates errors to the read side of the pipe. - NettyStreamManager now throws if the file being transferred does not exist. - The network library now correctly handles zero-sized streams. Author: Marcelo Vanzin Closes #9941 from vanzin/SPARK-11956. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 19 +++++++++---- .../spark/rpc/netty/NettyStreamManager.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 27 +++++++++++++----- .../client/TransportResponseHandler.java | 28 ++++++++++++------- .../org/apache/spark/network/StreamSuite.java | 23 ++++++++++++++- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68701f609f77a..c8fa870f50e68 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -27,7 +27,7 @@ import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv( @volatile private var error: Throwable = _ - def setError(e: Throwable): Unit = error = e + def setError(e: Throwable): Unit = { + error = e + source.close() + } override def read(dst: ByteBuffer): Int = { - if (error != null) { - throw error + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error } - source.read(dst) } override def close(): Unit = source.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index eb1d2604fb235..a2768b4252dcb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") } - require(file != null, s"File not found: $streamId") + require(file != null && file.isFile(), s"File not found: $streamId") new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2b664c6313efa..6cc958a5f6bc8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) val fileUri = env.fileServer.addFile(file) + val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val destDir = Utils.createTempDir() - val destFile = new File(destDir, file.getName()) - val destJar = new File(destDir, jar.getName()) - val sm = new SecurityManager(conf) val hc = SparkHadoopUtil.get.conf - Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) - Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(file, destFile)) - assert(Files.equal(jar, destJar)) + val files = Seq( + (file, fileUri), + (empty, emptyUri), + (jar, jarUri)) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be181e0660826..4c15045363b84 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -185,16 +185,24 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 00158fd081626..538f3efe8d6f2 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -103,6 +105,8 @@ public ManagedBuffer openStream(String streamId) { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -138,6 +142,18 @@ public static void tearDown() { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +242,11 @@ public void run() { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); } From a54c1b671a18dd5d53c07ff5ca2452e9df17d69d Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 Nov 2015 10:36:35 -0800 Subject: [PATCH 0923/1168] [SPARK-11984][SQL][PYTHON] Fix typos in doc for pivot for scala and python Author: felixcheung Closes #9967 from felixcheung/pypivotdoc. --- python/pyspark/sql/group.py | 6 +++--- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index d8ed7eb2dda64..1911588309aff 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,11 +169,11 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and preform the specified aggregation. + """Pivots a column of the current DataFrame and perform the specified aggregation. :param pivot_col: Column to pivot - :param values: Optional list of values of pivotColumn that will be translated to columns in - the output data frame. If values are not provided the method with do an immediate call + :param values: Optional list of values of pivot column that will be translated to columns in + the output DataFrame. If values are not provided the method will do an immediate call to .distinct() on the pivot column. >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index abd531c4ba541..13341a88a6b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -282,7 +282,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -321,7 +321,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -353,7 +353,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. From 82aefee79ecd678a3e37b802c0d570b7ceca46bc Mon Sep 17 00:00:00 2001 From: Zhongshuai Pei Date: Wed, 25 Nov 2015 10:37:34 -0800 Subject: [PATCH 0924/1168] [SPARK-11974][CORE] Not all the temp dirs had been deleted when the JVM exits deleting the temp dir like that ``` scala> import scala.collection.mutable import scala.collection.mutable scala> val a = mutable.Set(1,2,3,4,7,0,8,98,9) a: scala.collection.mutable.Set[Int] = Set(0, 9, 1, 2, 3, 7, 4, 8, 98) scala> a.foreach(x => {a.remove(x) }) scala> a.foreach(println(_)) 98 ``` You may not modify a collection while traversing or iterating over it.This can not delete all element of the collection Author: Zhongshuai Pei Closes #9951 from DoingDone9/Bug_RemainDir. --- .../scala/org/apache/spark/util/ShutdownHookManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index db4a8b304ec3e..4012dca3ecdf8 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -57,7 +57,9 @@ private[spark] object ShutdownHookManager extends Logging { // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => + // we need to materialize the paths to delete because deleteRecursively removes items from + // shutdownDeletePaths as we are traversing through it. + shutdownDeletePaths.toArray.foreach { dirPath => try { logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) From 720a5d19b6aa6822aa19e790d69a9735fccd23d9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 11:11:39 -0800 Subject: [PATCH 0925/1168] [SPARK-11969] [SQL] [PYSPARK] visualization of SQL query for pyspark Currently, we does not have visualization for SQL query from Python, this PR fix that. cc zsxwing Author: Davies Liu Closes #9949 from davies/pyspark_sql_ui. --- python/pyspark/sql/dataframe.py | 2 +- .../main/scala/org/apache/spark/sql/DataFrame.scala | 7 +++++++ .../org/apache/spark/sql/execution/python.scala | 12 +++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0dd75ba7ca820..746bb55e14f22 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -277,7 +277,7 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d8319b9a97fcf..6197f10813a3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1735,6 +1736,12 @@ class DataFrame private[sql]( EvaluatePython.javaToPython(rdd) } + protected[sql] def collectToPython(): Int = { + withNewExecutionId { + PythonRDD.collectAndServe(javaToPython.rdd) + } + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// // Deprecated methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index d611b0011da16..defcec95fb555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -121,11 +121,13 @@ object EvaluatePython { def takeAndServe(df: DataFrame, n: Int): Int = { registerPicklers() - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } } /** From 3aa7ebb7494f34c21863968a32563d95efea3832 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 25 Nov 2015 11:35:52 -0800 Subject: [PATCH 0926/1168] [MINOR] Remove unnecessary spaces in `include_example.rb` Author: Yu ISHIKAWA Closes #9960 from yu-iskw/minor-remove-spaces. --- docs/_plugins/include_example.rb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 549f81fe1b1bc..564c86680f68e 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -20,12 +20,12 @@ module Jekyll class IncludeExampleTag < Liquid::Tag - + def initialize(tag_name, markup, tokens) @markup = markup super end - + def render(context) site = context.registers[:site] config_dir = '../examples/src/main' @@ -37,7 +37,7 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - + rendered_code = Pygments.highlight(code, :lexer => @lang) hint = "
    Find full example code at " \ @@ -45,7 +45,7 @@ def render(context) rendered_code + hint end - + # Trim the code block so as to have the same indention, regardless of their positions in the # code file. def trim_codeblock(lines) From d44191140766e59faf20488cad28bfaa02084bed Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 11:37:42 -0800 Subject: [PATCH 0927/1168] [DOCUMENTATION] Fix minor doc error Author: Jeff Zhang Closes #9956 from zjffdu/dev_typo. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4de202d7f7631..741d6b2b37a87 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,7 +35,7 @@ val sc = new SparkContext(conf) {% endhighlight %} Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may -actually require one to prevent any sort of starvation issues. +actually require more than 1 thread to prevent any sort of starvation issues. Properties that specify some time duration should be configured with a unit of time. The following format is accepted: From 6d45ed8f127877316dac5483979b5e512c30f123 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 25 Nov 2015 11:39:00 -0800 Subject: [PATCH 0928/1168] [SPARK-10864][WEB UI] app name is hidden if window is resized Currently the Web UI navbar has a minimum width of 1200px; so if a window is resized smaller than that the app name goes off screen. The 1200px width seems to have been chosen since it fits the longest example app name without wrapping. To work with smaller window widths I made the tabs wrap since it looked better than wrapping the app name. This is a distinct change in how the navbar looks and I'm not sure if it's what we actually want to do. Other notes: - min-width set to 600px to keep the tabs from wrapping individually (will need to be adjusted if tabs are added) - app name will also wrap (making three levels) if a really really long app name is used Author: Alex Bozarth Closes #9874 from ajbozarth/spark10864. --- .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++------ core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 04f3070d25b4a..c628a0c706553 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -16,14 +16,9 @@ */ .navbar { - height: 50px; font-size: 15px; margin-bottom: 15px; - min-width: 1200px -} - -.navbar .navbar-inner { - height: 50px; + min-width: 600px; } .navbar .brand { @@ -46,6 +41,7 @@ .navbar-text { height: 50px; line-height: 3.3; + white-space: nowrap; } table.sortable thead { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 84a1116a5c498..1e8194f57888e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -210,10 +210,10 @@ private[spark] object UIUtils extends Logging { {org.apache.spark.SPARK_VERSION}
    - +
    From 29556bfa56aa59344f383df30937d79e6b00f07a Mon Sep 17 00:00:00 2001 From: wangt Date: Wed, 25 Nov 2015 11:41:05 -0800 Subject: [PATCH 0929/1168] [SPARK-11880][WINDOWS][SPARK SUBMIT] bin/load-spark-env.cmd loads spark-env.cmd from wrong directory * On windows the `bin/load-spark-env.cmd` tries to load `spark-env.cmd` from `%~dp0..\..\conf`, where `~dp0` points to `bin` and `conf` is only one level up. * Updated `bin/load-spark-env.cmd` to load `spark-env.cmd` from `%~dp0..\conf`, instead of `%~dp0..\..\conf` Author: wangt Closes #9863 from toddwan/master. --- bin/load-spark-env.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 36d932c453b6f..59080edd294f2 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\..\conf + set user_conf_dir=%~dp0..\conf ) call :LoadSparkEnv From fd0d48b1120c2d4e831d74bd412a1c995f1c2a6a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 25 Nov 2015 11:42:53 -0800 Subject: [PATCH 0930/1168] [SPARK-10558][CORE] Fix wrong executor state in Master `ExecutorAdded` can only be sent to `AppClient` when worker report back the executor state as `LOADING`, otherwise because of concurrency issue, `AppClient` will possibly receive `ExectuorAdded` at first, then `ExecutorStateUpdated` with `LOADING` state. Also Master will change the executor state from `LAUNCHING` to `RUNNING` (`AppClient` report back the state as `RUNNING`), then to `LOADING` (worker report back to state as `LOADING`), it should be `LAUNCHING` -> `LOADING` -> `RUNNING`. Also it is wrongly shown in master UI, the state of executor should be `RUNNING` rather than `LOADING`: ![screen shot 2015-09-11 at 2 30 28 pm](https://cloud.githubusercontent.com/assets/850797/9809254/3155d840-5899-11e5-8cdf-ad06fef75762.png) Author: jerryshao Closes #8714 from jerryshao/SPARK-10558. --- .../org/apache/spark/deploy/ExecutorState.scala | 2 +- .../org/apache/spark/deploy/client/AppClient.scala | 3 --- .../org/apache/spark/deploy/master/Master.scala | 14 +++++++++++--- .../org/apache/spark/deploy/worker/Worker.scala | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index efa88c62e1f5d..69c98e28931d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index afab362e213b5..df6ba7d669ce9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -178,9 +178,6 @@ private[spark] class AppClient( val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not - // guaranteed), `ExecutorStateChanged` may be sent to a dead master. - sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b25a487806c7f..9952c97dbdffc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -253,9 +253,17 @@ private[deploy] class Master( execOption match { case Some(exec) => { val appInfo = idToApp(appId) + val oldState = exec.state exec.state = state - if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } + + if (state == ExecutorState.RUNNING) { + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") + appInfo.resetRetryCount() + } + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -702,8 +710,8 @@ private[deploy] class Master( worker.addExecutor(exec) worker.endpoint.send(LaunchExecutor(masterUrl, exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) - exec.application.driver.send(ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) + exec.application.driver.send( + ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index a45867e7680ec..418faf8fc967f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -469,7 +469,7 @@ private[deploy] class Worker( executorDir, workerUri, conf, - appLocalDirs, ExecutorState.LOADING) + appLocalDirs, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ From a7b1fce358f3a04dba4e07b109e4325a1546b35c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 11:47:21 -0800 Subject: [PATCH 0931/1168] [SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java The Python exception track in TransformFunction and TransformFunctionSerializer is not sent back to Java. Py4j just throws a very general exception, which is hard to debug. This PRs adds `getFailure` method to get the failure message in Java side. Author: Shixiong Zhu Closes #9922 from zsxwing/SPARK-11935. --- python/pyspark/streaming/tests.py | 82 ++++++++++++++++++- python/pyspark/streaming/util.py | 29 +++++-- .../streaming/api/python/PythonDStream.scala | 52 ++++++++++-- 3 files changed, 144 insertions(+), 19 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a0e0267cafa58..d380d697bc51c 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -404,17 +404,69 @@ def func(dstream): self._test_func(input, func, expected) def test_failed_func(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) input = [self.sc.parallelize([d], 1) for d in range(4)] input_stream = self.ssc.queueStream(input) def failed_func(i): - raise ValueError("failed") + raise ValueError("This is a special error") input_stream.map(failed_func).pprint() self.ssc.start() try: self.ssc.awaitTerminationOrTimeout(10) except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func2(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream1 = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input) + + def failed_func(rdd1, rdd2): + raise ValueError("This is a special error") + + input_stream1.transformWith(failed_func, input_stream2, True).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func_with_reseting_failure(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + if i == 1: + # Make it fail in the second batch + raise ValueError("This is a special error") + else: + return i + + # We should be able to see the results of the 3rd and 4th batches even if the second batch + # fails + expected = [[0], [2], [3]] + self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3)) + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) return self.fail("a failed func should throw an error") @@ -780,6 +832,34 @@ def tearDown(self): if self.cpd is not None: shutil.rmtree(self.cpd) + def test_transform_function_serializer_failure(self): + inputd = tempfile.mkdtemp() + self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure") + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + + # A function that cannot be serialized + def process(time, rdd): + sc.parallelize(range(1, 10)) + + ssc.textFileStream(inputd).foreachRDD(process) + return ssc + + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + try: + self.ssc.start() + except: + import traceback + failure = traceback.format_exc() + self.assertTrue( + "It appears that you are attempting to reference SparkContext" in failure) + return + + self.fail("using SparkContext in process should fail because it's not Serializable") + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 767c732eb90b4..c7f02bca2ae38 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -38,12 +38,15 @@ def __init__(self, ctx, func, *deserializers): self.func = func self.deserializers = deserializers self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.failure = None def rdd_wrapper(self, func): self._rdd_wrapper = func return self def call(self, milliseconds, jrdds): + # Clear the failure + self.failure = None try: if self.ctx is None: self.ctx = SparkContext._active_spark_context @@ -62,9 +65,11 @@ def call(self, milliseconds, jrdds): r = self.func(t, *rdds) if r: return r._jrdd - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunction(%s)" % self.func @@ -89,22 +94,28 @@ def __init__(self, ctx, serializer, gateway=None): self.serializer = serializer self.gateway = gateway or self.ctx._gateway self.gateway.jvm.PythonDStream.registerSerializer(self) + self.failure = None def dumps(self, id): + # Clear the failure + self.failure = None try: func = self.gateway.gateway_property.pool[id] return bytearray(self.serializer.dumps((func.func, func.deserializers))) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() def loads(self, data): + # Clear the failure + self.failure = None try: f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index dfc569451df86..994309ddd0a3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -26,6 +26,7 @@ import scala.language.existentials import py4j.GatewayServer +import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -40,6 +41,13 @@ import org.apache.spark.util.Utils */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] + + /** + * Get the failure, if any, in the last call to `call`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction { private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] def loads(bytes: Array[Byte]): PythonTransformFunction + + /** + * Get the failure, if any, in the last call to `dumps` or `loads`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) - .map(_.rdd) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava - Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { - pfunc.call(time.milliseconds, rdds) + callPythonTransformFunction(time.milliseconds, rdds) + } + + private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = { + val resultRDD = pfunc.call(time, rdds) + val failure = pfunc.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + resultRDD } private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer { /* * Register a serializer from Python, should be called during initialization */ - def register(ser: PythonTransformFunctionSerializer): Unit = { + def register(ser: PythonTransformFunctionSerializer): Unit = synchronized { serializer = ser } - def serialize(func: PythonTransformFunction): Array[Byte] = { + def serialize(func: PythonTransformFunction): Array[Byte] = synchronized { require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - serializer.dumps(id) + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results } - def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized { require(serializer != null, "Serializer has not been registered!") - serializer.loads(bytes) + val pfunc = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + pfunc } } From 94164a257b113da66ab4c246406ac474954fcf2c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 12:58:18 -0800 Subject: [PATCH 0932/1168] [SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are cleaned up. This change does a couple of different things to make sure that the RpcEnv-level code and the network library agree about the status of outstanding RPCs. For RPCs that do not expect a reply ("RpcEnv.send"), support for one way messages (hello CORBA!) was added to the network layer. This is a "fire and forget" message that does not require any state to be kept by the TransportClient; as a result, the RpcEnv 'Ack' message is not needed anymore. For RPCs that do expect a reply ("RpcEnv.ask"), the network library now returns the internal RPC id; if the RpcEnv layer decides to time out the RPC before the network layer does, it now asks the TransportClient to forget about the RPC, so that if the network-level timeout occurs, the client is not killed. As part of implementing the above, I cleaned up some of the code in the netty rpc backend, removing types that were not necessary and factoring out some common code. Of interest is a slight change in the exceptions when posting messages to a stopped RpcEnv; that's mostly to avoid nasty error messages from the local-cluster backend when shutting down, which pollutes the terminal output. Author: Marcelo Vanzin Closes #9917 from vanzin/SPARK-11866. --- .../spark/deploy/worker/ExecutorRunner.scala | 6 +- .../apache/spark/rpc/netty/Dispatcher.scala | 55 +++---- .../org/apache/spark/rpc/netty/Inbox.scala | 28 ++-- .../spark/rpc/netty/NettyRpcCallContext.scala | 35 +--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 153 +++++++----------- .../org/apache/spark/rpc/netty/Outbox.scala | 64 ++++++-- .../apache/spark/rpc/netty/InboxSuite.scala | 6 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../spark/network/client/TransportClient.java | 34 +++- .../spark/network/protocol/Message.java | 4 +- .../network/protocol/MessageDecoder.java | 3 + .../spark/network/protocol/OneWayMessage.java | 75 +++++++++ .../spark/network/sasl/SaslRpcHandler.java | 5 + .../spark/network/server/RpcHandler.java | 36 +++++ .../server/TransportRequestHandler.java | 18 ++- .../apache/spark/network/ProtocolSuite.java | 2 + .../spark/network/RpcIntegrationSuite.java | 31 ++++ .../spark/network/sasl/SparkSaslSuite.java | 9 ++ 18 files changed, 374 insertions(+), 192 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 3aef0515cbf6e..25a17473e4b53 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + try { + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + } catch { + case e: IllegalStateException => logWarning(e.getMessage(), e) + } } /** Stop this executor runner, including killing the process it launched */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index eb25d6c7b721b..533c9847661b6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -106,44 +106,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage( - name, - _ => message, - () => { logWarning(s"Drop $message because $name has been stopped") }) + postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) } } /** Posts a message sent by a remote endpoint. */ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, sender, callback, message.senderAddress, message.needReply) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessage(message.receiver.name, createMessage, onEndpointStopped) + val rpcCallContext = + new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } /** Posts a message sent by a local endpoint. */ def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } + val rpcCallContext = + new LocalNettyRpcCallContext(message.senderAddress, p) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) + } - postMessage(message.receiver.name, createMessage, onEndpointStopped) + /** Posts a one-way message. */ + def postOneWayMessage(message: RequestMessage): Unit = { + postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), + (e) => throw e) } /** @@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { */ private def postMessage( endpointName: String, - createMessageFn: NettyRpcEndpointRef => InboxMessage, - callbackIfStopped: () => Unit): Unit = { + message: InboxMessage, + callbackIfStopped: (Exception) => Unit): Unit = { val shouldCallOnStop = synchronized { val data = endpoints.get(endpointName) if (stopped || data == null) { true } else { - data.inbox.post(createMessageFn(data.ref)) + data.inbox.post(message) receivers.offer(data) false } } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - callbackIfStopped() + val error = if (stopped) { + new IllegalStateException("RpcEnv already stopped.") + } else { + new SparkException(s"Could not find $endpointName or it has been stopped.") + } + callbackIfStopped(error) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 464027f07cc88..175463cc10319 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} private[netty] sealed trait InboxMessage -private[netty] case class ContentMessage( +private[netty] case class OneWayMessage( + senderAddress: RpcAddress, + content: Any) extends InboxMessage + +private[netty] case class RpcMessage( senderAddress: RpcAddress, content: Any, - needReply: Boolean, context: NettyRpcCallContext) extends InboxMessage private[netty] case object OnStart extends InboxMessage @@ -96,29 +99,24 @@ private[netty] class Inbox( while (true) { safelyCall(endpoint) { message match { - case ContentMessage(_sender, content, needReply, context) => - // The partial function to call - val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive + case RpcMessage(_sender, content, context) => try { - pf.applyOrElse[Any, Unit](content, { msg => + endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg => throw new SparkException(s"Unsupported message $message from ${_sender}") }) - if (!needReply) { - context.finish() - } } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - } + context.sendFailure(e) // Throw the exception -- this exception will be caught by the safelyCall function. // The endpoint's onError function will be called. throw e } + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 21d5bb4923d1b..6637e2321f673 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -23,49 +23,28 @@ import org.apache.spark.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress, - needReply: Boolean) +private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) extends RpcCallContext with Logging { protected def send(message: Any): Unit override def reply(response: Any): Unit = { - if (needReply) { - send(AskResponse(endpointRef, response)) - } else { - throw new IllegalStateException( - s"Cannot send $response to the sender because the sender does not expect a reply") - } + send(response) } override def sendFailure(e: Throwable): Unit = { - if (needReply) { - send(AskResponse(endpointRef, RpcFailure(e))) - } else { - logError(e.getMessage, e) - throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it") - } + send(RpcFailure(e)) } - def finish(): Unit = { - if (!needReply) { - send(Ack(endpointRef)) - } - } } /** * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. */ private[netty] class LocalNettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, - needReply: Boolean, p: Promise[Any]) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { p.success(message) @@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext( */ private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, - endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, - senderAddress: RpcAddress, - needReply: Boolean) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + senderAddress: RpcAddress) + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c8fa870f50e68..c7d74fa1d9195 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv( private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { if (receiver.client != null) { - receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + message.sendWith(receiver.client) } else { require(receiver.address != null, "Cannot send message to client endpoint with no listen address.") @@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - val promise = Promise[Any]() - dispatcher.postLocalMessage(message, promise) - promise.future.onComplete { - case Success(response) => - val ack = response.asInstanceOf[Ack] - logTrace(s"Received ack from ${ack.sender}") - case Failure(e) => - logWarning(s"Exception when sending $message", e) - }(ThreadUtils.sameThread) + dispatcher.postOneWayMessage(message) } else { // Message to a remote RPC endpoint. - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - logWarning(s"Exception when sending $message", e) - }, - (client, response) => { - val ack = deserialize[Ack](client, response) - logDebug(s"Receive ack from ${ack.sender}") - })) + postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) } } @@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv( clientFactory.createClient(address.host, address.port) } - private[netty] def ask(message: RequestMessage): Future[Any] = { + private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() val remoteAddr = message.receiver.address + + def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning(s"Ignored failure: $e") + } + } + + def onSuccess(reply: Any): Unit = reply match { + case RpcFailure(e) => onFailure(e) + case rpcReply => + if (!promise.trySuccess(rpcReply)) { + logWarning(s"Ignored message: $reply") + } + } + if (remoteAddr == address) { val p = Promise[Any]() - dispatcher.postLocalMessage(message, p) p.future.onComplete { - case Success(response) => - val reply = response.asInstanceOf[AskResponse] - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - case Failure(e) => - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) } else { - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - }, - (client, response) => { - val reply = deserialize[AskResponse](client, response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - })) + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) } - promise.future + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure( + new TimeoutException("Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } private[netty] def serialize(content: Any): Array[Byte] = { @@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef( override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - val promise = Promise[Any]() - val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete { v => - timeoutCancelable.cancel(true) - if (!promise.tryComplete(v)) { - logWarning(s"Ignore message $v") - } - }(ThreadUtils.sameThread) - promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) } override def toString: String = s"NettyRpcEndpointRef(${_address})" @@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef( * The message that is sent from the sender to the receiver. */ private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) - -/** - * The base trait for all messages that are sent back from the receiver to the sender. - */ -private[netty] trait ResponseMessage - -/** - * The reply for `ask` from the receiver side. - */ -private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) - extends ResponseMessage - -/** - * A message to send back to the receiver side. It's necessary because [[TransportClient]] only - * clean the resources when it receives a reply. - */ -private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) /** * A response that indicates some failure happens in the receiver side. @@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postRemoteMessage(messageToDispatch, callback) + } + + override def receive( + client: TransportClient, + message: Array[Byte]): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postOneWayMessage(messageToDispatch) + } + + private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) @@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler( dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) - val messageToDispatch = if (requestMessage.senderAddress == null) { - // Create a new message with the socket address of the client as the sender. - RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, - requestMessage.needReply) - } else { - requestMessage - } - dispatcher.postRemoteMessage(messageToDispatch, callback) + if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. + RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + } else { + requestMessage + } } override def getStreamManager: StreamManager = streamManager diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2f6817f2eb935..36fdd00bbc4c2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.RpcAddress -private[netty] case class OutboxMessage(content: Array[Byte], - _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) { +private[netty] sealed trait OutboxMessage { - def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { - override def onFailure(e: Throwable): Unit = { - _onFailure(e) - } + def sendWith(client: TransportClient): Unit - override def onSuccess(response: Array[Byte]): Unit = { - _onSuccess(client, response) - } + def onFailure(e: Throwable): Unit + +} + +private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage + with Logging { + + override def sendWith(client: TransportClient): Unit = { + client.send(content) + } + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Failed to send one-way RPC.", e) + } + +} + +private[netty] case class RpcOutboxMessage( + content: Array[Byte], + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, Array[Byte]) => Unit) + extends OutboxMessage with RpcResponseCallback { + + private var client: TransportClient = _ + private var requestId: Long = _ + + override def sendWith(client: TransportClient): Unit = { + this.client = client + this.requestId = client.sendRpc(content, this) + } + + def onTimeout(): Unit = { + require(client != null, "TransportClient has not yet been set.") + client.removeRpcRequest(requestId) + } + + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + _onSuccess(client, response) } } @@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.createCallback(_client)) + message.sendWith(_client) } else { assert(stopped == true) } @@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(e) + message.onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 276c077b3d13e..2136795b18813 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", true, null) + val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite { new Thread { override def run(): Unit = { for (_ <- 0 until 100) { - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) } exitLatch.countDown() diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ccca795683da3..323184cdd9b6e 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) - .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 876fcd846791c..8a58e7b24585b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -36,6 +37,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; @@ -205,8 +207,12 @@ public void operationComplete(ChannelFuture future) throws Exception { /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. + * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -235,6 +241,8 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + + return requestId; } /** @@ -265,11 +273,35 @@ public void onFailure(Throwable e) { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(byte[] message) { + channel.writeAndFlush(new OneWayMessage(message)); + } + + /** + * Removes any state associated with the given RPC. + * + * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + /** Mark this channel as having timed out. */ public void timeOut() { this.timedOut = true; } + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d01598c20f16f..39afd03db60ee 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -28,7 +28,8 @@ public interface Message extends Encodable { public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8); + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9); private final byte id; @@ -55,6 +56,7 @@ public static Type decode(ByteBuf buf) { case 6: return StreamRequest; case 7: return StreamResponse; case 8: return StreamFailure; + case 9: return OneWayMessage; default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3c04048f3821a..074780f2b95ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + case StreamRequest: return StreamRequest.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 0000000000000..95a0270be3da9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage implements RequestMessage { + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public OneWayMessage(byte[] message) { + this.message = message; + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + return Encoders.ByteArrays.encodedLength(message); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.ByteArrays.encode(buf, message); + } + + public static OneWayMessage decode(ByteBuf buf) { + byte[] message = Encoders.ByteArrays.decode(buf); + return new OneWayMessage(message); + } + + @Override + public int hashCode() { + return Arrays.hashCode(message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 7033adb9cae6f..830db94b890c5 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -108,6 +108,11 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index dbb7f95f55bc0..65109ddfe13b9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,9 @@ package org.apache.spark.network.server; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -24,6 +27,9 @@ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -47,6 +53,19 @@ public abstract void receive( */ public abstract StreamManager getStreamManager(); + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, byte[] message) { + receive(client, message, ONE_WAY_CALLBACK); + } + /** * Invoked when the connection associated with the given client has been invalidated. * No further requests will come from this client. @@ -54,4 +73,21 @@ public abstract void receive( public void connectionTerminated(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(byte[] response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4f67bd573be21..db18ea77d1073 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -27,13 +28,14 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; @@ -95,6 +97,8 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof OneWayMessage) { + processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); } else { @@ -156,6 +160,14 @@ public void onFailure(Throwable e) { } } + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.message); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } + } + /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 22b451fc0e60e..1aa20900ffe74 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; @@ -84,6 +85,7 @@ public void requests() { testClientToServer(new RpcRequest(12345, new byte[0])); testClientToServer(new RpcRequest(12345, new byte[100])); testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new byte[100])); } @Test diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8eb56bdd9846f..88fa2258bb794 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,9 +17,11 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -46,6 +48,7 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { @@ -64,12 +67,19 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + String msg = new String(message, Charsets.UTF_8); + oneWayMsgs.add(msg); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -158,6 +168,27 @@ public void sendSuccessAndFailure() throws Exception { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(message.getBytes(Charsets.UTF_8)); + assertEquals(0, client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index b146899670180..a6f180bc40c9a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.*; import java.io.File; +import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -353,6 +354,14 @@ public void testRpcHandlerDelegate() throws Exception { verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); } + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; From 5e582881b7eed2a14106769f61d4efe2e3a5a23e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 13:45:41 -0800 Subject: [PATCH 0933/1168] Fix Aggregator documentation (rename present to finish). --- .../scala/org/apache/spark/sql/expressions/Aggregator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index b0cd32b5f73e6..65117d5824755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} * def zero: Int = 0 * def reduce(b: Int, a: Data): Int = b + a.i * def merge(b1: Int, b2: Int): Int = b1 + b2 - * def present(r: Int): Int = r + * def finish(r: Int): Int = r * }.toColumn() * * val ds: Dataset[Data] = ... From 7d88f536f3d3640e8915f3a4f2d6c6017bdccc2a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 25 Nov 2015 13:51:30 -0800 Subject: [PATCH 0934/1168] [SPARK-11983][SQL] remove all unused codegen fallback trait Author: Daoyuan Wang Closes #9966 from adrian-wang/removeFallback. --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 3 +-- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- .../spark/sql/catalyst/expressions/NonFoldableLiteral.scala | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 533d17ea5c172..a2c6c39fd8ce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -104,8 +104,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9e484c5ed83bf..adef6050c3565 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -66,7 +66,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -117,7 +117,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 31ecf4a9e810a..118fd695fe2f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.types._ * A literal value that is not foldable. Used in expression codegen testing to test code path * that behave differently based on foldable values. */ -case class NonFoldableLiteral(value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = false override def nullable: Boolean = true From 01551c503cf56db2079ab803765269ae87e160ef Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Wed, 25 Nov 2015 15:13:13 -0800 Subject: [PATCH 0935/1168] [SPARK-11206] Support SQL UI on the history server On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an `onOtherEvent` method to the `SparkListener` trait and post all SQL related events to the same event bus. 2. Two SQL events `SparkListenerSQLExecutionStart` and `SparkListenerSQLExecutionEnd` are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait `SparkHistoryListenerFactory` is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using `java.util.ServiceLoader`. Author: Carson Wang Closes #9297 from carsonwang/SqlHistoryUI. --- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 18 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 327 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6a..7262c960ed6bb 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b88..23bc9a2e81727 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba6063..e6b24afd88ad4 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528cf..eaa07acc5132e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f1..075a7f13172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. + */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aad..95722a07144ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec8..8da6884a38535 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05af..7f5d713ec6505 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -511,6 +516,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 0000000000000..507100be90967 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 46bf544fd885f..1c2ac5f6f11bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,6 +1263,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1307,6 +1309,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1355,9 +1361,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c94..34971986261c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 0000000000000..486ce34064e43 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 0000000000000..2708219ad3485 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e95..6c0f6f8a52dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1c..c74ad40406992 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6a..e19a1e3e5851f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc6..4f50b2ecdc8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3fe..7af0ff09c5c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 5e2b4154dd7ce..ebfa1eaf3e5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096c..f93d081d0c30e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62ed..e7b376548787c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From dd50d760298470e273b70c5c88d57c3c80b3007f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 21:25:20 -0800 Subject: [PATCH 0936/1168] [SPARK-12003] [SQL] remove the prefix for name after expanded star Right now, the expended start will include the name of expression as prefix for column, that's not better than without expending, we should not have the prefix. Author: Davies Liu Closes #9984 from davies/expand_star. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1b2a8dc4c7f14..4f89b462a6ce3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -204,7 +204,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case s: StructType => s.zipWithIndex.map { case (f, i) => val extract = GetStructField(attribute.get, i) - Alias(extract, target.get + "." + f.name)() + Alias(extract, f.name)() } case _ => { From 54b1c8987a2dfe44cc15ee4d0a7f9a612bc264c0 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 23:24:33 -0800 Subject: [PATCH 0937/1168] [SPARK-11980][SPARK-10621][SQL] Fix json_tuple and add test cases for Added Python test cases for the function `isnan`, `isnull`, `nanvl` and `json_tuple`. Fixed a bug in the function `json_tuple` rxin , could you help me review my changes? Please let me know anything is missing. Thank you! Have a good Thanksgiving day! Author: gatorsmile Closes #9977 from gatorsmile/json_tuple. --- python/pyspark/sql/functions.py | 44 +++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e3786e0fa5fb2..90625949f747a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -286,14 +286,6 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def monotonicallyIncreasingId(): - """ - .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. - """ - return monotonically_increasing_id() - - @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -305,6 +297,10 @@ def input_file_name(): @since(1.6) def isnan(col): """An expression that returns true iff the column is NaN. + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnan(_to_java_column(col))) @@ -313,11 +309,23 @@ def isnan(col): @since(1.6) def isnull(col): """An expression that returns true iff the column is null. + + >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.4) +def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. @@ -344,6 +352,10 @@ def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. Both inputs should be floating point columns (DoubleType or FloatType). + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() + [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @@ -1460,6 +1472,7 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix @since(1.6) def get_json_object(col, path): """ @@ -1468,22 +1481,33 @@ def get_json_object(col, path): :param col: string column in json format :param path: path to the json object to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ + get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) return Column(jc) +@ignore_unicode_prefix @since(1.6) -def json_tuple(col, fields): +def json_tuple(col, *fields): """Creates a new row for a json column according to the given field names. :param col: string column in json format :param fields: list of fields to extract + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) return Column(jc) From eb846cec095f93ad2ffdcf2cae62f927993e57bd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 23:31:21 -0800 Subject: [PATCH 0938/1168] [SPARK-11999][CORE] Fix the issue that ThreadUtils.newDaemonCachedThreadPool doesn't cache any task In the previous codes, `newDaemonCachedThreadPool` uses `SynchronousQueue`, which is wrong. `SynchronousQueue` is an empty queue that cannot cache any task. This patch uses `LinkedBlockingQueue` to fix it along with other fixes to make sure `newDaemonCachedThreadPool` can use at most `maxThreadNumber` threads, and after that, cache tasks to `LinkedBlockingQueue`. Author: Shixiong Zhu Closes #9978 from zsxwing/cached-threadpool. --- .../org/apache/spark/util/ThreadUtils.scala | 14 ++++-- .../apache/spark/util/ThreadUtilsSuite.scala | 45 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 53283448c87b1..f9fbe2ff858ce 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -56,10 +56,18 @@ private[spark] object ThreadUtils { * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. */ - def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + def newDaemonCachedThreadPool( + prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = { val threadFactory = namedThreadFactory(prefix) - new ThreadPoolExecutor( - 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + val threadPool = new ThreadPoolExecutor( + maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks + maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used + keepAliveSeconds, + TimeUnit.SECONDS, + new LinkedBlockingQueue[Runnable], + threadFactory) + threadPool.allowCoreThreadTimeOut(true) + threadPool } /** diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 620e4debf4e08..92ae038967528 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -24,6 +24,8 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, Future} import scala.util.Random +import org.scalatest.concurrent.Eventually._ + import org.apache.spark.SparkFunSuite class ThreadUtilsSuite extends SparkFunSuite { @@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite { } } + test("newDaemonCachedThreadPool") { + val maxThreadNumber = 10 + val startThreadsLatch = new CountDownLatch(maxThreadNumber) + val latch = new CountDownLatch(1) + val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "ThreadUtilsSuite-newDaemonCachedThreadPool", + maxThreadNumber, + keepAliveSeconds = 2) + try { + for (_ <- 1 to maxThreadNumber) { + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + startThreadsLatch.countDown() + latch.await(10, TimeUnit.SECONDS) + } + }) + } + startThreadsLatch.await(10, TimeUnit.SECONDS) + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 0) + + // Submit a new task and it should be put into the queue since the thread number reaches the + // limitation + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + latch.await(10, TimeUnit.SECONDS) + } + }) + + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 1) + + latch.countDown() + eventually(timeout(10.seconds)) { + // All threads should be stopped after keepAliveSeconds + assert(cachedThreadPool.getActiveCount === 0) + assert(cachedThreadPool.getPoolSize === 0) + } + } finally { + cachedThreadPool.shutdownNow() + } + } + test("sameThread") { val callerThreadName = Thread.currentThread().getName() val f = Future { From 9799433df8b1cc6833fc843cbd2deb3855afce53 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 26 Nov 2015 00:19:42 -0800 Subject: [PATCH 0939/1168] [SPARK-11973] [SQL] push filter through aggregation with alias and literals Currently, filter can't be pushed through aggregation with alias or literals, this patch fix that. After this patch, the time of TPC-DS query 4 go down to 13 seconds from 141 seconds (10x improvements). cc nongli yhuai Author: Davies Liu Closes #9959 from davies/push_filter2. --- .../sql/catalyst/expressions/predicates.scala | 9 ++++ .../sql/catalyst/optimizer/Optimizer.scala | 28 ++++++---- .../optimizer/FilterPushdownSuite.scala | 53 +++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68557479a9591..304b438c84ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -65,6 +65,15 @@ trait PredicateHelper { } } + // Substitute any known alias from a map. + protected def replaceAlias( + condition: Expression, + aliases: AttributeMap[Expression]): Expression = { + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f4dba67f13b54..52f609bc158ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -640,20 +640,14 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe filter } else { // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + val pushedCondition = + deterministic.map(replaceAlias(_, aliasMap)).reduce(And) Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { - condition.transform { - case a: Attribute => sourceAliases.getOrElse(a, a) - } - } } /** @@ -690,12 +684,24 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions) + + def hasAggregate(expression: Expression): Boolean = expression match { + case agg: AggregateExpression => true + case other => expression.children.exists(hasAggregate) + } + // Create a map of Alias for expressions that does not have AggregateExpression + val aliasMap = AttributeMap(aggregateExpressions.collect { + case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + }) + + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => + val replaced = replaceAlias(conjunct, aliasMap) + replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild)) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0290fafe879f6..0128c220baaca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -697,4 +697,57 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("aggregate: push down filters with alias") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where(('c === 2L || 'aa > 4) && 'aa < 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where('a + 1 < 3) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where('c === 2L || 'aa > 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with literal") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L && 'd === "s") + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where("s" === "s") + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filters which is nondeterministic") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .analyze + + comparePlans(optimized, correctAnswer) + } } From fd196d996b121f1fd66075b86823713d9d6a4096 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Nov 2015 01:15:05 -0800 Subject: [PATCH 0940/1168] [SPARK-12005][SQL] Work around VerifyError in HyperLogLogPlusPlus. Just move the code around a bit; that seems to make the JVM happy. Author: Marcelo Vanzin Closes #9985 from vanzin/SPARK-12005. --- .../expressions/aggregate/HyperLogLogPlusPlus.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8a95c541f1e86..e1fd22e36764e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -63,11 +63,7 @@ case class HyperLogLogPlusPlus( def this(child: Expression, relativeSD: Expression) = { this( child = child, - relativeSD = relativeSD match { - case Literal(d: Double, DoubleType) => d - case _ => - throw new AnalysisException("The second argument should be a double literal.") - }, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), mutableAggBufferOffset = 0, inputAggBufferOffset = 0) } @@ -448,4 +444,11 @@ object HyperLogLogPlusPlus { Array(189083, 185696.913, 182348.774, 179035.946, 175762.762, 172526.444, 169329.754, 166166.099, 163043.269, 159958.91, 156907.912, 153906.845, 150924.199, 147996.568, 145093.457, 142239.233, 139421.475, 136632.27, 133889.588, 131174.2, 128511.619, 125868.621, 123265.385, 120721.061, 118181.769, 115709.456, 113252.446, 110840.198, 108465.099, 106126.164, 103823.469, 101556.618, 99308.004, 97124.508, 94937.803, 92833.731, 90745.061, 88677.627, 86617.47, 84650.442, 82697.833, 80769.132, 78879.629, 77014.432, 75215.626, 73384.587, 71652.482, 69895.93, 68209.301, 66553.669, 64921.981, 63310.323, 61742.115, 60205.018, 58698.658, 57190.657, 55760.865, 54331.169, 52908.167, 51550.273, 50225.254, 48922.421, 47614.533, 46362.049, 45098.569, 43926.083, 42736.03, 41593.473, 40425.26, 39316.237, 38243.651, 37170.617, 36114.609, 35084.19, 34117.233, 33206.509, 32231.505, 31318.728, 30403.404, 29540.0550000001, 28679.236, 27825.862, 26965.216, 26179.148, 25462.08, 24645.952, 23922.523, 23198.144, 22529.128, 21762.4179999999, 21134.779, 20459.117, 19840.818, 19187.04, 18636.3689999999, 17982.831, 17439.7389999999, 16874.547, 16358.2169999999, 15835.684, 15352.914, 14823.681, 14329.313, 13816.897, 13342.874, 12880.882, 12491.648, 12021.254, 11625.392, 11293.7610000001, 10813.697, 10456.209, 10099.074, 9755.39000000001, 9393.18500000006, 9047.57900000003, 8657.98499999999, 8395.85900000005, 8033, 7736.95900000003, 7430.59699999995, 7258.47699999996, 6924.58200000005, 6691.29399999999, 6357.92500000005, 6202.05700000003, 5921.19700000004, 5628.28399999999, 5404.96799999999, 5226.71100000001, 4990.75600000005, 4799.77399999998, 4622.93099999998, 4472.478, 4171.78700000001, 3957.46299999999, 3868.95200000005, 3691.14300000004, 3474.63100000005, 3341.67200000002, 3109.14000000001, 3071.97400000005, 2796.40399999998, 2756.17799999996, 2611.46999999997, 2471.93000000005, 2382.26399999997, 2209.22400000005, 2142.28399999999, 2013.96100000001, 1911.18999999994, 1818.27099999995, 1668.47900000005, 1519.65800000005, 1469.67599999998, 1367.13800000004, 1248.52899999998, 1181.23600000003, 1022.71900000004, 1088.20700000005, 959.03600000008, 876.095999999903, 791.183999999892, 703.337000000058, 731.949999999953, 586.86400000006, 526.024999999907, 323.004999999888, 320.448000000091, 340.672999999952, 309.638999999966, 216.601999999955, 102.922999999952, 19.2399999999907, -0.114000000059605, -32.6240000000689, -89.3179999999702, -153.497999999905, -64.2970000000205, -143.695999999996, -259.497999999905, -253.017999999924, -213.948000000091, -397.590000000084, -434.006000000052, -403.475000000093, -297.958000000101, -404.317000000039, -528.898999999976, -506.621000000043, -513.205000000075, -479.351000000024, -596.139999999898, -527.016999999993, -664.681000000099, -680.306000000099, -704.050000000047, -850.486000000034, -757.43200000003, -713.308999999892) ) // scalastyle:on + + private def validateDoubleLiteral(exp: Expression): Double = exp match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + } + } From 557167d868a039dec29573ff41406455c7a78c85 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 11:31:28 -0800 Subject: [PATCH 0941/1168] [SPARK-11863][SQL] Unable to resolve order by if it contains mixture of aliases and real columns this is based on https://github.com/apache/spark/pull/9844, with some bug fix and clean up. The problems is that, normal operator should be resolved based on its child, but `Sort` operator can also be resolved based on its grandchild. So we have 3 rules that can resolve `Sort`: `ResolveReferences`, `ResolveSortReferences`(if grandchild is `Project`) and `ResolveAggregateFunctions`(if grandchild is `Aggregate`). For example, `select c1 as a , c2 as b from tab group by c1, c2 order by a, c2`, we need to resolve `a` and `c2` for `Sort`. Firstly `a` will be resolved in `ResolveReferences` based on its child, and when we reach `ResolveAggregateFunctions`, we will try to resolve both `a` and `c2` based on its grandchild, but failed because `a` is not a legal aggregate expression. whoever merge this PR, please give the credit to dilipbiswal Author: Dilip Biswal Author: Wenchen Fan Closes #9961 from cloud-fan/sort. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47962ebe6ef82..94ffbbb2e5c65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -630,7 +630,8 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -663,13 +664,19 @@ class Analyzer( } } + val sortOrdersMap = unresolvedSortOrders + .map(new TreeNodeRef(_)) + .zip(evaluatedOrderings) + .toMap + val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) + // Since we don't rely on sort.resolved as the stop condition for this rule, // we need to check this and prevent applying this rule multiple times - if (sortOrder == evaluatedOrderings) { + if (sortOrder == finalSortOrders) { sort } else { Project(aggregate.output, - Sort(evaluatedOrderings, global, + Sort(finalSortOrders, global, aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e051069951887..aeeca802d8bb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -220,6 +220,24 @@ class AnalysisSuite extends AnalysisTest { // checkUDF(udf4, expected4) } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + val alias1 = a.as("a1") + val alias2 = c.as("a2") + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3")) + .orderBy('a1.asc, 'c.asc) + + val expected = testRelation2 + .groupBy(a, c)(alias1, alias2, alias3) + .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) + .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) + checkAnalysis(plan, expected) + } + test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) From 3e0102cab31d88d89ea82ca022372983069df17d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 26 Nov 2015 16:20:08 -0800 Subject: [PATCH 0942/1168] [SPARK-11998][SQL][TEST-HADOOP2.0] When downloading Hadoop artifacts from maven, we need to try to download the version that is used by Spark If we need to download Hive/Hadoop artifacts, try to download a Hadoop that matches the Hadoop used by Spark. If the Hadoop artifact cannot be resolved (e.g. Hadoop version is a vendor specific version like 2.0.0-cdh4.1.1), we will use Hadoop 2.4.0 (we used to hard code this version as the hadoop that we will download from maven) and we will not share Hadoop classes. I tested this match in my laptop with the following confs (these confs are used by our builds). All tests are good. ``` build/sbt -Phadoop-1 -Dhadoop.version=1.2.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.2 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Pkinesis-asl -Phive-thriftserver -Phive ``` Author: Yin Huai Closes #9979 from yhuai/versionsSuite. --- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../hive/client/IsolatedClientLoader.scala | 62 +++++++++++++++---- .../spark/sql/hive/client/VersionsSuite.scala | 23 +++++-- 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8a4264194ae8d..e83941c2ecf66 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.util.VersionInfo import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry @@ -288,7 +289,8 @@ class HiveContext private[hive]( logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") IsolatedClientLoader.forVersion( - version = hiveMetastoreVersion, + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, config = allConfig, barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e041e0d8e5ae8..010051d255fdc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -34,23 +34,51 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ -private[hive] object IsolatedClientLoader { +private[hive] object IsolatedClientLoader extends Logging { /** * Creates isolated Hive client loaders by downloading the requested version from maven. */ def forVersion( - version: String, + hiveMetastoreVersion: String, + hadoopVersion: String, config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { - val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, - downloadVersion(resolvedVersion, ivyPath)) + val resolvedVersion = hiveVersion(hiveMetastoreVersion) + // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact + // with the given version, we will use Hadoop 2.4.0 and then will not share Hadoop classes. + var sharesHadoopClasses = true + val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { + resolvedVersions((resolvedVersion, hadoopVersion)) + } else { + val (downloadedFiles, actualHadoopVersion) = + try { + (downloadVersion(resolvedVersion, hadoopVersion, ivyPath), hadoopVersion) + } catch { + case e: RuntimeException if e.getMessage.contains("hadoop") => + // If the error message contains hadoop, it is probably because the hadoop + // version cannot be resolved (e.g. it is a vendor specific version like + // 2.0.0-cdh4.1.1). If it is the case, we will try just + // "org.apache.hadoop:hadoop-client:2.4.0". "org.apache.hadoop:hadoop-client:2.4.0" + // is used just because we used to hard code it as the hadoop artifact to download. + logWarning(s"Failed to resolve Hadoop artifacts for the version ${hadoopVersion}. " + + s"We will change the hadoop version from ${hadoopVersion} to 2.4.0 and try again. " + + "Hadoop classes will not be shared between Spark and Hive metastore client. " + + "It is recommended to set jars used by Hive metastore client through " + + "spark.sql.hive.metastore.jars in the production environment.") + sharesHadoopClasses = false + (downloadVersion(resolvedVersion, "2.4.0", ivyPath), "2.4.0") + } + resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) + resolvedVersions((resolvedVersion, actualHadoopVersion)) + } + new IsolatedClientLoader( - version = hiveVersion(version), + version = hiveVersion(hiveMetastoreVersion), execJars = files, config = config, + sharesHadoopClasses = sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -64,12 +92,15 @@ private[hive] object IsolatedClientLoader { case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + private def downloadVersion( + version: HiveVersion, + hadoopVersion: String, + ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ Seq("com.google.guava:guava:14.0.1", - "org.apache.hadoop:hadoop-client:2.4.0") + s"org.apache.hadoop:hadoop-client:$hadoopVersion") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -86,7 +117,10 @@ private[hive] object IsolatedClientLoader { tempDir.listFiles().map(_.toURI.toURL) } - private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] + // A map from a given pair of HiveVersion and Hadoop version to jar files. + // It is only used by forVersion. + private val resolvedVersions = + new scala.collection.mutable.HashMap[(HiveVersion, String), Seq[URL]] } /** @@ -106,6 +140,7 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. + * @param sharesHadoopClasses When true, we will share Hadoop classes between Spark and * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. */ @@ -114,6 +149,7 @@ private[hive] class IsolatedClientLoader( val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, + val sharesHadoopClasses: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, val sharedPrefixes: Seq[String] = Seq.empty, @@ -126,16 +162,20 @@ private[hive] class IsolatedClientLoader( /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray - protected def isSharedClass(name: String): Boolean = + protected def isSharedClass(name: String): Boolean = { + val isHadoopClass = + name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") + name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || - (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || + (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) + } /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7bc13bc60d30e..502b240f3650f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.hadoop.util.VersionInfo + import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} @@ -53,9 +55,11 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, - buildConf(), - ivyPath).createClient() + val badClient = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -85,7 +89,11 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient() + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "13", + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -99,7 +107,12 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient() + client = + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } test(s"$version: createDatabase") { From 343f9d7055f55ad0b38f0d3754036b1851704190 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 18:47:54 -0800 Subject: [PATCH 0943/1168] [SPARK-11973][SQL] Improve optimizer code readability. This is a followup for https://github.com/apache/spark/pull/9959. I added more documentation and rewrote some monadic code into simpler ifs. Author: Reynold Xin Closes #9995 from rxin/SPARK-11973. --- .../sql/catalyst/optimizer/Optimizer.scala | 50 +++++++++---------- .../optimizer/FilterPushdownSuite.scala | 2 +- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 52f609bc158ca..2901d8f2efddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer { ConstantFolding, LikeSimplification, BooleanSimplification, - RemoveDispensable, + RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -660,14 +660,14 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf g.child.outputSet + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) } else { filter } @@ -675,34 +675,34 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp } /** - * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference - * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath, - * and the rest should remain above. + * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only + * non-aggregate attributes (typically literals or grouping expressions). */ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - - def hasAggregate(expression: Expression): Boolean = expression match { - case agg: AggregateExpression => true - case other => expression.children.exists(hasAggregate) - } - // Create a map of Alias for expressions that does not have AggregateExpression - val aliasMap = AttributeMap(aggregateExpressions.collect { - case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + case filter @ Filter(condition, aggregate: Aggregate) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) }) - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => - val replaced = replaceAlias(conjunct, aliasMap) - replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic + // For each filter, expand the alias and check if the filter can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic } + if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) - val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + // If there is no more filter to stay up, just eliminate the filter. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) } else { filter } @@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * And also pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details @@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { /** * Removes nodes that are not necessary. */ -object RemoveDispensable extends Rule[LogicalPlan] { +object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child case PromotePrecision(child) => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0128c220baaca..fba4c5ca77d64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("aggregate: don't push down filters which is nondeterministic") { + test("aggregate: don't push down filters that are nondeterministic") { val originalQuery = testRelation .select('a, 'b) .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) From f3c3dd3b4bebbf257b5792793a9e8c39ec39a2bd Mon Sep 17 00:00:00 2001 From: muxator Date: Thu, 26 Nov 2015 18:52:20 -0800 Subject: [PATCH 0944/1168] doc typo: "classificaion" -> "classification" Author: muxator Closes #10008 from muxator/patch-1. --- docs/mllib-linear-methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 0c76e6e999465..132f8c354aa9c 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -122,7 +122,7 @@ Under the hood, linear methods use convex optimization methods to optimize the o [Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into categories. The most common classification type is -[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). From 4b7a494e8dea39d12cba741736ba3c6ad67e533c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 26 Nov 2015 18:56:22 -0800 Subject: [PATCH 0945/1168] [SPARK-11996][CORE] Make the executor thread dump work again In the previous implementation, the driver needs to know the executor listening address to send the thread dump request. However, in Netty RPC, the executor doesn't listen to any port, so the executor thread dump feature is broken. This patch makes the driver use the endpointRef stored in BlockManagerMasterEndpoint to send the thread dump request to fix it. Author: Shixiong Zhu Closes #9976 from zsxwing/executor-thread-dump. --- .../scala/org/apache/spark/SparkContext.scala | 10 ++--- .../org/apache/spark/executor/Executor.scala | 5 --- .../spark/executor/ExecutorEndpoint.scala | 43 ------------------- .../spark/storage/BlockManagerMaster.scala | 4 +- .../storage/BlockManagerMasterEndpoint.scala | 12 +++--- .../spark/storage/BlockManagerMessages.scala | 7 ++- .../storage/BlockManagerSlaveEndpoint.scala | 7 ++- project/MimaExcludes.scala | 8 ++++ 8 files changed, 29 insertions(+), 67 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 2c10779f2b893..b030d3c71dc20 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,20 +48,20 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -619,11 +619,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get - val endpointRef = env.rpcEnv.setupEndpointRef( - SparkEnv.executorActorSystemName, - RpcAddress(host, port), - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9e88d488c0379..6154f06e3ac11 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -85,10 +85,6 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an RpcEndpoint for receiving RPCs from the driver - private val executorEndpoint = env.rpcEnv.setupEndpoint( - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) - // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) @@ -136,7 +132,6 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.rpcEnv.stop(executorEndpoint) heartbeater.shutdown() heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala deleted file mode 100644 index cf362f8464735..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case TriggerThreadDump => - context.reply(Utils.getThreadDump()) - } - -} - -object ExecutorEndpoint { - val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f45bff34d4dbc..440c4c18aadd0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -87,8 +87,8 @@ class BlockManagerMaster( driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { + driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7db6035553ae6..41892b4ffce5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} @@ -75,8 +74,8 @@ class BlockManagerMasterEndpoint( case GetPeers(blockManagerId) => context.reply(getPeers(blockManagerId)) - case GetRpcHostPortForExecutor(executorId) => - context.reply(getRpcHostPortForExecutor(executorId)) + case GetExecutorEndpointRef(executorId) => + context.reply(getExecutorEndpointRef(executorId)) case GetMemoryStatus => context.reply(memoryStatus) @@ -388,15 +387,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its - * [[BlockManagerSlaveEndpoint]]. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. */ - private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + info.slaveEndpoint } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843d..f392a4a0cd9be 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,6 +42,11 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + /** + * Driver -> Executor message to trigger a thread dump. + */ + case object TriggerThreadDump extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +95,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e749631bf6f19..9eca902f7454e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,10 +19,10 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. For example, @@ -70,6 +70,9 @@ class BlockManagerSlaveEndpoint( case GetMatchingBlockIds(filter, _) => context.reply(blockManager.getMatchingBlockIds(filter)) + + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54a9ad956d119..566bfe8efb7a4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -147,6 +147,14 @@ object MimaExcludes { // SPARK-4557 Changed foreachRDD to use VoidFunction ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") ) case v if v.startsWith("1.5") => Seq( From 26bc2eebdb95572ea64bae1e66d2e46ee2b8f097 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 26 Nov 2015 19:00:36 -0800 Subject: [PATCH 0946/1168] [SPARK-12011][SQL] Stddev/Variance etc should support columnName as arguments Spark SQL aggregate function: ```Java stddev stddev_pop stddev_samp variance var_pop var_samp skewness kurtosis collect_list collect_set ``` should support ```columnName``` as arguments like other aggregate function(max/min/count/sum). Author: Yanbo Liang Closes #9994 from yanboliang/SPARK-12011. --- .../org/apache/spark/sql/functions.scala | 86 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 3 + 2 files changed, 89 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 276c5dfc8b062..e79defbbbdeea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -214,6 +214,16 @@ object functions extends LegacyFunctions { */ def collect_list(e: Column): Column = callUDF("collect_list", e) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(columnName: String): Column = collect_list(Column(columnName)) + /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * @@ -224,6 +234,16 @@ object functions extends LegacyFunctions { */ def collect_set(e: Column): Column = callUDF("collect_set", e) + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(columnName: String): Column = collect_set(Column(columnName)) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * @@ -312,6 +332,14 @@ object functions extends LegacyFunctions { */ def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * @@ -386,6 +414,14 @@ object functions extends LegacyFunctions { */ def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + /** * Aggregate function: alias for [[stddev_samp]]. * @@ -394,6 +430,14 @@ object functions extends LegacyFunctions { */ def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: alias for [[stddev_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(columnName: String): Column = stddev(Column(columnName)) + /** * Aggregate function: returns the sample standard deviation of * the expression in a group. @@ -403,6 +447,15 @@ object functions extends LegacyFunctions { */ def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + /** * Aggregate function: returns the population standard deviation of * the expression in a group. @@ -412,6 +465,15 @@ object functions extends LegacyFunctions { */ def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * @@ -452,6 +514,14 @@ object functions extends LegacyFunctions { */ def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + /** * Aggregate function: returns the unbiased variance of the values in a group. * @@ -460,6 +530,14 @@ object functions extends LegacyFunctions { */ def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + /** * Aggregate function: returns the population variance of the values in a group. * @@ -468,6 +546,14 @@ object functions extends LegacyFunctions { */ def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9c42f65bb6f52..b5c636d0de1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -261,6 +261,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + checkAnswer( + testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { From 2534c7cbc160842ab63ac90dc2f042511749bc17 Mon Sep 17 00:00:00 2001 From: mariusvniekerk Date: Thu, 26 Nov 2015 19:13:16 -0800 Subject: [PATCH 0947/1168] [SPARK-11881][SQL] Fix for postgresql fetchsize > 0 Reference: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor In order for PostgreSQL to honor the fetchSize non-zero setting, its Connection.autoCommit needs to be set to false. Otherwise, it will just quietly ignore the fetchSize setting. This adds a new side-effecting dialect specific beforeFetch method that will fire before a select query is ran. Author: mariusvniekerk Closes #9861 from mariusvniekerk/SPARK-11881. --- .../execution/datasources/jdbc/JDBCRDD.scala | 12 ++++++++++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 11 +++++++++++ .../apache/spark/sql/jdbc/PostgresDialect.scala | 17 ++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 89c850ce238d7..f9b72597dd2a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -224,6 +224,7 @@ private[sql] object JDBCRDD extends Logging { quotedColumns, filters, parts, + url, properties) } } @@ -241,6 +242,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -361,6 +363,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -489,6 +494,13 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.getAutoCommit && !conn.isClosed) { + try { + conn.commit() + } catch { + case e: Throwable => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index b3b2cb6178c52..13db141f27db6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.jdbc +import java.sql.Connection + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -97,6 +99,15 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index ed3faa1268635..3cf80f576e92c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, Types} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -70,4 +70,19 @@ private object PostgresDialect extends JdbcDialect { override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. + // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } } From c88c7608cc449f81a340a1c765381f9a1ee9ab63 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 26 Nov 2015 19:15:22 -0800 Subject: [PATCH 0948/1168] [SPARK-11917][PYSPARK] Add SQLContext#dropTempTable to PySpark Author: Jeff Zhang Closes #9903 from zjffdu/SPARK-11917. --- python/pyspark/sql/context.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index a49c1b58d0180..b05aa2f5c4cd7 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -445,6 +445,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") + @since(1.6) + def dropTempTable(self, tableName): + """ Remove the temp table from catalog. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> sqlContext.dropTempTable("table1") + """ + self._ssql_ctx.dropTempTable(tableName) + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. From 6ed33135a456b8adb6f1ab8761695c921bcfe39e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Nov 2015 19:17:46 -0800 Subject: [PATCH 0949/1168] [SPARK-11778][SQL] add regression test Fix regression test for SPARK-11778. marmbrus Could you please take a look? Thank you very much!! Author: Huaxin Gao Closes #9890 from huaxingao/spark-11778-regression-test. --- .../hive/HiveDataFrameAnalyticsSuite.scala | 10 ------ .../spark/sql/hive/HiveDataFrameSuite.scala | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index f19a74d4b3724..9864acf765265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,14 +34,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") - hiveContext.sql("create schema usrdb") - hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") - hiveContext.sql("drop table usrdb.test") - hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -78,10 +74,4 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } - - // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, - // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException - test("table name with schema") { - hiveContext.read.table("usrdb.test") - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala new file mode 100644 index 0000000000000..7fdc5d71937ff --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.QueryTest + +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { + test("table name with schema") { + // regression test for SPARK-11778 + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c int)") + hiveContext.read.table("usrdb.test") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") + } +} From c342ef16554fe96ae715b14232ede9c694ea0a0d Mon Sep 17 00:00:00 2001 From: Jeremy Derr Date: Thu, 26 Nov 2015 19:25:13 -0800 Subject: [PATCH 0950/1168] [SPARK-11991] fixes If `--private-ips` is required but not provided, spark_ec2.py may behave inappropriately, including attempting to ssh to localhost in attempts to verify ssh connectivity to the cluster. This fixes that behavior by raising a `UsageError` exception if `get_dns_name` is unable to determine a hostname as a result. Author: Jeremy Derr Closes #9975 from jcderr/SPARK-11991/ec_spark.py_hostname_check. --- ec2/spark_ec2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9fd652a3df4c4..84a950c9f6529 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -1242,6 +1242,10 @@ def get_ip_address(instance, private_ips=False): def get_dns_name(instance, private_ips=False): dns = instance.public_dns_name if not private_ips else \ instance.private_ip_address + if not dns: + raise UsageError("Failed to determine hostname of {0}.\n" + "Please check that you provided --private-ips if " + "necessary".format(instance)) return dns From a84c0ef3153cda92a6f5d72d55d51e4f2bfde15d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 19:36:43 -0800 Subject: [PATCH 0951/1168] Fix style violation for b63938a8b04 --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index f9b72597dd2a9..57a8a044a37cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD @@ -498,7 +500,7 @@ private[sql] class JDBCRDD( try { conn.commit() } catch { - case e: Throwable => logWarning("Exception committing transaction", e) + case NonFatal(e) => logWarning("Exception committing transaction", e) } } conn.close() From 8b9e7c0164eb5ae684853027be5204d1ad0aa9f4 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 21:04:40 -0800 Subject: [PATCH 0952/1168] [SPARK-11997] [SQL] NPE when save a DataFrame as parquet and partitioned by long column Check for partition column null-ability while building the partition spec. Author: Dilip Biswal Closes #10001 from dilipbiswal/spark-11997. --- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936d..9ace25dc7d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -607,7 +607,7 @@ abstract class HadoopFsRelation private[sql]( def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => Cast( - Literal.create(row.getString(i), StringType), + Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType).eval() }: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 70fae32b7e7a1..f777e973052d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -252,6 +252,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-11997 parquet with null partition values") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(1, 3) + .selectExpr("if(id % 2 = 0, null, id) AS n", "id") + .write.partitionBy("n").parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).filter("n is null"), + Row(2, null)) + } + } + // This test case is ignored because of parquet-mr bug PARQUET-370 ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => From fe0e447ed44137fd5f15f671afac8d43a6ed3f55 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 27 Nov 2015 11:48:01 -0800 Subject: [PATCH 0953/1168] [SPARK-12025][SPARKR] Rename some window rank function names for SparkR Change ```cumeDist -> cume_dist, denseRank -> dense_rank, percentRank -> percent_rank, rowNumber -> row_number``` at SparkR side. There are two reasons that we should make this change: * We should follow the [naming convention rule of R](http://www.inside-r.org/node/230645) * Spark DataFrame has deprecated the old convention (such as ```cumeDist```) and will remove it in Spark 2.0. It's better to fix this issue before 1.6 release, otherwise we will make breaking API change. cc shivaram sun-rui Author: Yanbo Liang Closes #10016 from yanboliang/SPARK-12025. --- R/pkg/NAMESPACE | 8 ++--- R/pkg/R/functions.R | 54 ++++++++++++++++---------------- R/pkg/R/generics.R | 16 +++++----- R/pkg/inst/tests/test_sparkSQL.R | 4 +-- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 260c9edce62e0..5d04dd6acaab8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -123,14 +123,14 @@ exportMethods("%in%", "count", "countDistinct", "crc32", - "cumeDist", + "cume_dist", "date_add", "date_format", "date_sub", "datediff", "dayofmonth", "dayofyear", - "denseRank", + "dense_rank", "desc", "endsWith", "exp", @@ -188,7 +188,7 @@ exportMethods("%in%", "next_day", "ntile", "otherwise", - "percentRank", + "percent_rank", "pmod", "quarter", "rand", @@ -200,7 +200,7 @@ exportMethods("%in%", "rint", "rlike", "round", - "rowNumber", + "row_number", "rpad", "rtrim", "second", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25a1f22101494..e98e7a0117ca0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2146,47 +2146,47 @@ setMethod("ifelse", ###################### Window functions###################### -#' cumeDist +#' cume_dist #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. #' #' N = total number of rows in the partition -#' cumeDist(x) = number of values before (and including) x / N +#' cume_dist(x) = number of values before (and including) x / N #' #' This is equivalent to the CUME_DIST function in SQL. #' -#' @rdname cumeDist -#' @name cumeDist +#' @rdname cume_dist +#' @name cume_dist #' @family window_funcs #' @export -#' @examples \dontrun{cumeDist()} -setMethod("cumeDist", +#' @examples \dontrun{cume_dist()} +setMethod("cume_dist", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) }) -#' denseRank +#' dense_rank #' #' Window function: returns the rank of rows within a window partition, without any gaps. -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. #' #' This is equivalent to the DENSE_RANK function in SQL. #' -#' @rdname denseRank -#' @name denseRank +#' @rdname dense_rank +#' @name dense_rank #' @family window_funcs #' @export -#' @examples \dontrun{denseRank()} -setMethod("denseRank", +#' @examples \dontrun{dense_rank()} +setMethod("dense_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "denseRank") + jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) }) @@ -2264,7 +2264,7 @@ setMethod("ntile", column(jc) }) -#' percentRank +#' percent_rank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. #' @@ -2274,15 +2274,15 @@ setMethod("ntile", #' #' This is equivalent to the PERCENT_RANK function in SQL. #' -#' @rdname percentRank -#' @name percentRank +#' @rdname percent_rank +#' @name percent_rank #' @family window_funcs #' @export -#' @examples \dontrun{percentRank()} -setMethod("percentRank", +#' @examples \dontrun{percent_rank()} +setMethod("percent_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "percentRank") + jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) }) @@ -2316,21 +2316,21 @@ setMethod("rank", base::rank(x, ...) }) -#' rowNumber +#' row_number #' #' Window function: returns a sequential number starting at 1 within a window partition. #' #' This is equivalent to the ROW_NUMBER function in SQL. #' -#' @rdname rowNumber -#' @name rowNumber +#' @rdname row_number +#' @name row_number #' @family window_funcs #' @export -#' @examples \dontrun{rowNumber()} -setMethod("rowNumber", +#' @examples \dontrun{row_number()} +setMethod("row_number", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") + jc <- callJStatic("org.apache.spark.sql.functions", "row_number") column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1b3f10ea04643..0c305441e043e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -700,9 +700,9 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname cumeDist +#' @rdname cume_dist #' @export -setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) +setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -728,9 +728,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname denseRank +#' @rdname dense_rank #' @export -setGeneric("denseRank", function(x) { standardGeneric("denseRank") }) +setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) #' @rdname explode #' @export @@ -872,9 +872,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname percentRank +#' @rdname percent_rank #' @export -setGeneric("percentRank", function(x) { standardGeneric("percentRank") }) +setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) #' @rdname pmod #' @export @@ -913,9 +913,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @export setGeneric("rint", function(x, ...) { standardGeneric("rint") }) -#' @rdname rowNumber +#' @rdname row_number #' @export -setGeneric("rowNumber", function(x) { standardGeneric("rowNumber") }) +setGeneric("row_number", function(x) { standardGeneric("row_number") }) #' @rdname rpad #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3f4f319fe745d..0fbe0658265b2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -861,8 +861,8 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cumeDist() + ntile(1) - c15 <- denseRank() + percentRank() + rank() + rowNumber() + c14 <- cume_dist() + ntile(1) + c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From a6ee8a05d2013756ce1b88f5039b47bf8a06976e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 27 Nov 2015 11:50:18 -0800 Subject: [PATCH 0954/1168] [SPARK-12021][STREAMING][TESTS] Fix the potential dead-lock in StreamingListenerSuite In StreamingListenerSuite."don't call ssc.stop in listener", after the main thread calls `ssc.stop()`, `StreamingContextStoppingCollector` may call `ssc.stop()` in the listener bus thread, which is a dead-lock. This PR updated `StreamingContextStoppingCollector` to only call `ssc.stop()` in the first batch to avoid the dead-lock. Author: Shixiong Zhu Closes #10011 from zsxwing/fix-test-deadlock. --- .../streaming/StreamingListenerSuite.scala | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index df4575ab25aad..04cd5bdc26be2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -222,7 +222,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val batchCounter = new BatchCounter(_ssc) _ssc.start() // Make sure running at least one batch - batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + if (!batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)) { + fail("The first batch cannot complete in 10 seconds") + } + // When reaching here, we can make sure `StreamingContextStoppingCollector` won't call + // `ssc.stop()`, so it's safe to call `_ssc.stop()` now. _ssc.stop() assert(contextStoppingCollector.sparkExSeen) } @@ -345,12 +349,21 @@ class FailureReasonsCollector extends StreamingListener { */ class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { @volatile var sparkExSeen = false + + private var isFirstBatch = true + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - try { - ssc.stop() - } catch { - case se: SparkException => - sparkExSeen = true + if (isFirstBatch) { + // We should only call `ssc.stop()` in the first batch. Otherwise, it's possible that the main + // thread is calling `ssc.stop()`, while StreamingContextStoppingCollector is also calling + // `ssc.stop()` in the listener thread, which becomes a dead-lock. + isFirstBatch = false + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } } } } From 67b921ef8b39e79ab3acffe040fc50c88a5ec808 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 27 Nov 2015 15:11:13 -0800 Subject: [PATCH 0955/1168] [SPARK-12020][TESTS][TEST-HADOOP2.0] PR builder cannot trigger hadoop 2.0 test https://issues.apache.org/jira/browse/SPARK-12020 Author: Yin Huai Closes #10010 from yhuai/SPARK-12020. --- dev/run-tests-jenkins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e189..4f390ef1eaa32 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -164,7 +164,7 @@ def main(): # Switch the Hadoop profile based on the PR title: if "test-hadoop1.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.2" in ghprb_pull_title: + if "test-hadoop2.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" From 131352bfee7840858f3a58103d0654593ea93dd0 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Nov 2015 22:44:08 -0800 Subject: [PATCH 0956/1168] [SPARK-12028] [SQL] get_json_object returns an incorrect result when the value is null literals When calling `get_json_object` for the following two cases, both results are `"null"`: ```scala val tuple: Seq[(String, String)] = ("5", """{"f1": null}""") :: Nil val df: DataFrame = tuple.toDF("key", "jstring") val res = df.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` ```scala val tuple2: Seq[(String, String)] = ("5", """{"f1": "null"}""") :: Nil val df2: DataFrame = tuple2.toDF("key", "jstring") val res3 = df2.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` Fixed the problem and also added a test case. Author: gatorsmile Closes #10018 from gatorsmile/get_json_object. --- .../expressions/jsonExpressions.scala | 7 +++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd73236a7876..4991b9cb54e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -298,8 +298,11 @@ case class GetJsonObject(json: Expression, path: Expression) case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => // exact field match - p.nextToken() - evaluatePath(p, g, style, xs) + if (p.nextToken() != JsonToken.VALUE_NULL) { + evaluatePath(p, g, style, xs) + } else { + false + } case (FIELD_NAME, Wildcard :: xs) => // wildcard field match diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 14fd56fc8c222..1f384edf321b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -39,6 +39,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { ("6", "[invalid JSON string]") :: Nil + test("function get_json_object - null") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.get_json_object($"jstring", "$.f1"), + functions.get_json_object($"jstring", "$.f2"), + functions.get_json_object($"jstring", "$.f3"), + functions.get_json_object($"jstring", "$.f4"), + functions.get_json_object($"jstring", "$.f5")), + expected) + } + test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") val expected = From 5e4c4a848505214f7616fed044eac4b7b2331a47 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:02:05 -0800 Subject: [PATCH 0957/1168] [SPARK-12029][SPARKR] Improve column functions signature, param check, tests, fix doc and add examples shivaram sun-rui Author: felixcheung Closes #10019 from felixcheung/rfunctionsdoc. --- R/pkg/R/functions.R | 121 +++++++++++++++++++++++-------- R/pkg/inst/tests/test_sparkSQL.R | 9 ++- 2 files changed, 96 insertions(+), 34 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e98e7a0117ca0..b30331c61c9a7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -878,7 +878,7 @@ setMethod("rtrim", #'} setMethod("sd", signature(x = "Column"), - function(x, na.rm = FALSE) { + function(x) { # In R, sample standard deviation is calculated with the sd() function. stddev_samp(x) }) @@ -1250,7 +1250,7 @@ setMethod("upper", #'} setMethod("var", signature(x = "Column"), - function(x, y = NULL, na.rm = FALSE, use) { + function(x) { # In R, sample variance is calculated with the var() function. var_samp(x) }) @@ -1467,6 +1467,7 @@ setMethod("pmod", signature(y = "Column"), #' @name approxCountDistinct #' @return the approximate number of distinct items in a group. #' @export +#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.05) { @@ -1481,14 +1482,16 @@ setMethod("approxCountDistinct", #' @name countDistinct #' @return the number of distinct items in a group. #' @export +#' @examples \dontrun{countDistinct(df$c)} setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcol <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function (x) { + stopifnot(class(x) == "Column") x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - jcol) + jcols) column(jc) }) @@ -1501,10 +1504,14 @@ setMethod("countDistinct", #' @rdname concat #' @name concat #' @export +#' @examples \dontrun{concat(df$strings, df$strings2)} setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1518,11 +1525,15 @@ setMethod("concat", #' @rdname greatest #' @name greatest #' @export +#' @examples \dontrun{greatest(df$c, df$d)} setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1530,17 +1541,21 @@ setMethod("greatest", #' least #' #' Returns the least value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' This function takes at least 2 parameters. It will return null if all parameters are null. #' #' @family normal_funcs #' @rdname least #' @name least #' @export +#' @examples \dontrun{least(df$c, df$d)} setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1549,11 +1564,10 @@ setMethod("least", #' #' Computes the ceiling of the given value. #' -#' @family math_funcs #' @rdname ceil -#' @name ceil -#' @aliases ceil +#' @name ceiling #' @export +#' @examples \dontrun{ceiling(df$c)} setMethod("ceiling", signature(x = "Column"), function(x) { @@ -1564,11 +1578,10 @@ setMethod("ceiling", #' #' Computes the signum of the given value. #' -#' @family math_funcs #' @rdname signum -#' @name signum -#' @aliases signum +#' @name sign #' @export +#' @examples \dontrun{sign(df$c)} setMethod("sign", signature(x = "Column"), function(x) { signum(x) @@ -1578,11 +1591,10 @@ setMethod("sign", signature(x = "Column"), #' #' Aggregate function: returns the number of distinct items in a group. #' -#' @family agg_funcs #' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct +#' @name n_distinct #' @export +#' @examples \dontrun{n_distinct(df$c)} setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) @@ -1592,11 +1604,10 @@ setMethod("n_distinct", signature(x = "Column"), #' #' Aggregate function: returns the number of items in a group. #' -#' @family agg_funcs #' @rdname count -#' @name count -#' @aliases count +#' @name n #' @export +#' @examples \dontrun{n(df$c)} setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -1617,6 +1628,7 @@ setMethod("n", signature(x = "Column"), #' @rdname date_format #' @name date_format #' @export +#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -1631,6 +1643,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @export +#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -1649,6 +1662,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' @rdname instr #' @name instr #' @export +#' @examples \dontrun{instr(df$c, 'b')} setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -1663,13 +1677,18 @@ setMethod("instr", signature(y = "Column", x = "character"), #' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first #' Sunday after 2015-07-27. #' -#' Day of the week parameter is case insensitive, and accepts: +#' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' #' @family datetime_funcs #' @rdname next_day #' @name next_day #' @export +#' @examples +#'\dontrun{ +#'next_day(df$d, 'Sun') +#'next_day(df$d, 'Sunday') +#'} setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -1684,6 +1703,7 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @export +#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -1697,8 +1717,8 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @name add_months #' @family datetime_funcs #' @rdname add_months -#' @name add_months #' @export +#' @examples \dontrun{add_months(df$d, 1)} setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -1713,6 +1733,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @rdname date_add #' @name date_add #' @export +#' @examples \dontrun{date_add(df$d, 1)} setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -1727,6 +1748,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname date_sub #' @name date_sub #' @export +#' @examples \dontrun{date_sub(df$d, 1)} setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -1735,16 +1757,19 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, #' and returns the result as a string column. #' -#' If d is 0, the result has no decimal point or fractional part. -#' If d < 0, the result will be null.' +#' If x is 0, the result has no decimal point or fractional part. +#' If x < 0, the result will be null. #' +#' @param y column to format +#' @param x number of decimal place to format to #' @family string_funcs #' @rdname format_number #' @name format_number #' @export +#' @examples \dontrun{format_number(df$n, 4)} setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1764,6 +1789,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' @rdname sha2 #' @name sha2 #' @export +#' @examples \dontrun{sha2(df$c, 256)} setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -1779,6 +1805,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' @rdname shiftLeft #' @name shiftLeft #' @export +#' @examples \dontrun{shiftLeft(df$c, 1)} setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1796,6 +1823,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' @rdname shiftRight #' @name shiftRight #' @export +#' @examples \dontrun{shiftRight(df$c, 1)} setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1813,6 +1841,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @export +#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1830,6 +1859,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @rdname concat_ws #' @name concat_ws #' @export +#' @examples \dontrun{concat_ws('-', df$s, df$d)} setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) @@ -1845,6 +1875,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @rdname conv #' @name conv #' @export +#' @examples \dontrun{conv(df$n, 2, 16)} setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -1864,6 +1895,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' @rdname expr #' @name expr #' @export +#' @examples \dontrun{expr('length(name)')} setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -1878,6 +1910,7 @@ setMethod("expr", signature(x = "character"), #' @rdname format_string #' @name format_string #' @export +#' @examples \dontrun{format_string('%d %s', df$a, df$b)} setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { jcols <- lapply(list(x, ...), function(arg) { arg@jc }) @@ -1897,6 +1930,11 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname from_unixtime #' @name from_unixtime #' @export +#' @examples +#'\dontrun{ +#'from_unixtime(df$t) +#'from_unixtime(df$t, 'yyyy/MM/dd HH') +#'} setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1915,6 +1953,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @rdname locate #' @name locate #' @export +#' @examples \dontrun{locate('b', df$c, 1)} setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 0) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1931,6 +1970,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @rdname lpad #' @name lpad #' @export +#' @examples \dontrun{lpad(df$c, 6, '#')} setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1947,12 +1987,13 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname rand #' @name rand #' @export +#' @examples \dontrun{rand()} setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") column(jc) }) -#' @family normal_funcs + #' @rdname rand #' @name rand #' @export @@ -1970,12 +2011,13 @@ setMethod("rand", signature(seed = "numeric"), #' @rdname randn #' @name randn #' @export +#' @examples \dontrun{randn()} setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") column(jc) }) -#' @family normal_funcs + #' @rdname randn #' @name randn #' @export @@ -1993,6 +2035,7 @@ setMethod("randn", signature(seed = "numeric"), #' @rdname regexp_extract #' @name regexp_extract #' @export +#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -2010,6 +2053,7 @@ setMethod("regexp_extract", #' @rdname regexp_replace #' @name regexp_replace #' @export +#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -2027,6 +2071,7 @@ setMethod("regexp_replace", #' @rdname rpad #' @name rpad #' @export +#' @examples \dontrun{rpad(df$c, 6, '#')} setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2040,12 +2085,17 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' Returns the substring from string str before count occurrences of the delimiter delim. #' If count is positive, everything the left of the final delimiter (counting from left) is #' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' right) is returned. substring_index performs a case-sensitive match when searching for delim. #' #' @family string_funcs #' @rdname substring_index #' @name substring_index #' @export +#' @examples +#'\dontrun{ +#'substring_index(df$c, '.', 2) +#'substring_index(df$c, '.', -1) +#'} setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -2066,6 +2116,7 @@ setMethod("substring_index", #' @rdname translate #' @name translate #' @export +#' @examples \dontrun{translate(df$c, 'rnlt', '123')} setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -2082,12 +2133,18 @@ setMethod("translate", #' @rdname unix_timestamp #' @name unix_timestamp #' @export +#' @examples +#'\dontrun{ +#'unix_timestamp() +#'unix_timestamp(df$t) +#'unix_timestamp(df$t, 'yyyy-MM-dd HH') +#'} setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2096,7 +2153,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2113,7 +2170,9 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @family normal_funcs #' @rdname when #' @name when +#' @seealso \link{ifelse} #' @export +#' @examples \dontrun{when(df$age == 2, df$age + 1)} setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc @@ -2130,7 +2189,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @family normal_funcs #' @rdname ifelse #' @name ifelse +#' @seealso \link{when} #' @export +#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0fbe0658265b2..899fc3b977385 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -880,14 +880,15 @@ test_that("column functions", { expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) - expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + df4 <- select(df, countDistinct(df$age, df$name)) + expect_equal(collect(df4)[[1, 1]], 2) + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) - expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) - df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) - expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") # Test array_contains() and sort_array() df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) From 3392386076ae05231f373e0d58772a9b8a79810e Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:16:21 -0800 Subject: [PATCH 0958/1168] [SPARK-9319][SPARKR] Add support for setting column names, types Add support for for colnames, colnames<-, coltypes<- Also added tests for names, names<- which have no test previously. I merged with PR 8984 (coltypes). Clicked the wrong thing, crewed up the PR. Recreated it here. Was #9218 shivaram sun-rui Author: felixcheung Closes #9654 from felixcheung/colnamescoltypes. --- R/pkg/NAMESPACE | 6 +- R/pkg/R/DataFrame.R | 166 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 20 +++- R/pkg/R/types.R | 8 ++ R/pkg/inst/tests/test_sparkSQL.R | 40 +++++++- 5 files changed, 185 insertions(+), 55 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5d04dd6acaab8..43e5e0119e7fe 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -27,7 +27,10 @@ exportMethods("arrange", "attach", "cache", "collect", + "colnames", + "colnames<-", "coltypes", + "coltypes<-", "columns", "count", "cov", @@ -56,6 +59,7 @@ exportMethods("arrange", "mutate", "na.omit", "names", + "names<-", "ncol", "nrow", "orderBy", @@ -276,4 +280,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") \ No newline at end of file + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a13e7a36766d..f89e2682d9e29 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,6 +254,7 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns + #' @export #' @examples #'\dontrun{ @@ -262,6 +263,7 @@ setMethod("dtypes", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -290,6 +292,121 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Get column types of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @return value A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @name coltypes +#' @family DataFrame functions +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param value A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. +#' @rdname coltypes +#' @name coltypes<- +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + if (!is.na(value[i])) { + stype <- rToSQLTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -2102,52 +2219,3 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) - -#' Returns the column types of a DataFrame. -#' -#' @name coltypes -#' @title Get column types of a DataFrame -#' @family dataframe_funcs -#' @param x (DataFrame) -#' @return value (character) A character vector with the column types of the given DataFrame -#' @rdname coltypes -#' @examples \dontrun{ -#' irisDF <- createDataFrame(sqlContext, iris) -#' coltypes(irisDF) -#' } -setMethod("coltypes", - signature(x = "DataFrame"), - function(x) { - # Get the data types of the DataFrame by invoking dtypes() function - types <- sapply(dtypes(x), function(x) {x[[2]]}) - - # Map Spark data types into R's data types using DATA_TYPES environment - rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { - - # Check for primitive types - type <- PRIMITIVE_TYPES[[x]] - - if (is.null(type)) { - # Check for complex types - for (t in names(COMPLEX_TYPES)) { - if (substring(x, 1, nchar(t)) == t) { - type <- COMPLEX_TYPES[[t]] - break - } - } - - if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) - } - } - type - }) - - # Find which types don't have mapping to R - naIndices <- which(is.na(rTypes)) - - # Assign the original scala data types to the unmatched ones - rTypes[naIndices] <- types[naIndices] - - rTypes - }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0c305441e043e..711ce38f9e104 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,6 +385,22 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) +#' @rdname columns +#' @export +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) + +#' @rdname columns +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) @@ -1081,7 +1097,3 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") - -#' @rdname coltypes -#' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index 1828c23ab0f6d..dae4fe858bdbc 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -41,3 +41,11 @@ COMPLEX_TYPES <- list( # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. +rToSQLTypes <- as.environment(list( + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", + "character" = "string", + "logical" = "boolean")) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 899fc3b977385..d3b2f20bf81c5 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -622,6 +622,26 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form expect_equal(testNames[2], "name") }) +test_that("names() colnames() set the column names", { + df <- jsonFile(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") + + # Test base::colnames base::names + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x","Y") + expect_equal(colnames(m2), c("x", "Y")) + + z <- list(a = 1, b = "c", c = 1:3) + expect_equal(names(z)[3], "c") + names(z)[3] <- "c2" + expect_equal(names(z)[3], "c2") +}) + test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) @@ -1617,7 +1637,7 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) -test_that("Method coltypes() to get R's data types of a DataFrame", { +test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) data <- data.frame(c1=c(1,2,3), @@ -1636,6 +1656,24 @@ test_that("Method coltypes() to get R's data types of a DataFrame", { x <- createDataFrame(sqlContext, list(list(as.environment( list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") + + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) unlink(parquetPath) From 187e403478ca70fee341bf3ca9fc830d8f469edb Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sun, 29 Nov 2015 11:08:26 -0800 Subject: [PATCH 0959/1168] [SPARK-11781][SPARKR] SparkR has problem in inferring type of raw type. Author: Sun Rui Closes #9769 from sun-rui/SPARK-11781. --- R/pkg/R/DataFrame.R | 34 ++++++++++++++++------------- R/pkg/R/SQLContext.R | 2 +- R/pkg/R/types.R | 37 ++++++++++++++++++-------------- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f89e2682d9e29..a82ded9c51fac 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -793,8 +793,8 @@ setMethod("dim", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - names <- columns(x) - ncol <- length(names) + dtypes <- dtypes(x) + ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() @@ -817,25 +817,29 @@ setMethod("collect", # data of complex type can be held. But getting a cell from a column # of list type returns a list instead of a vector. So for columns of # non-complex type, append them as vector. + # + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] + colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[names[colIndex]]] <- col + df[[colName]] <- col } else { - # TODO: more robust check on column of primitive types - vec <- do.call(c, col) - if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + colType <- dtypes[[colIndex]][[2]] + # Note that "binary" columns behave like complex types. + if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { + vec <- do.call(c, col) + stopifnot(class(vec) != "list") + df[[colName]] <- vec } else { - # For columns of complex type, be careful to access them. - # Get a column of complex type returns a list. - # Get a cell from a column of complex type returns a list instead of a vector. - df[[names[colIndex]]] <- col - } + df[[colName]] <- col + } + } } + df } - df - } - }) + }) #' Limit #' diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a62b25fde926d..85541c8e22447 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -63,7 +63,7 @@ infer_type <- function(x) { }) type <- Reduce(paste0, type) type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") - } else if (length(x) > 1) { + } else if (length(x) > 1 && type != "binary") { paste0("array<", infer_type(x[[1]]), ">") } else { type diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index dae4fe858bdbc..1f06af7e904fe 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -19,25 +19,30 @@ # values are equivalent R types. This is stored in an environment to allow for # more efficient look up (environments use hashmaps). PRIMITIVE_TYPES <- as.environment(list( - "byte"="integer", - "tinyint"="integer", - "smallint"="integer", - "integer"="integer", - "bigint"="numeric", - "float"="numeric", - "double"="numeric", - "decimal"="numeric", - "string"="character", - "binary"="raw", - "boolean"="logical", - "timestamp"="POSIXct", - "date"="Date")) + "tinyint" = "integer", + "smallint" = "integer", + "int" = "integer", + "bigint" = "numeric", + "float" = "numeric", + "double" = "numeric", + "decimal" = "numeric", + "string" = "character", + "binary" = "raw", + "boolean" = "logical", + "timestamp" = "POSIXct", + "date" = "Date", + # following types are not SQL types returned by dtypes(). They are listed here for usage + # by checkType() in schema.R. + # TODO: refactor checkType() in schema.R. + "byte" = "integer", + "integer" = "integer" + )) # The complex data types. These do not have any direct mapping to R's types. COMPLEX_TYPES <- list( - "map"=NA, - "array"=NA, - "struct"=NA) + "map" = NA, + "array" = NA, + "struct" = NA) # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d3b2f20bf81c5..92ec82096c6df 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -72,6 +72,8 @@ test_that("infer types and check types", { expect_equal(infer_type(e), "map") expect_error(checkType("map"), "Key type in a map must be string or character") + + expect_equal(infer_type(as.raw(c(1, 2, 3))), "binary") }) test_that("structType and structField", { @@ -250,6 +252,10 @@ test_that("create DataFrame from list or data.frame", { mtcarsdf <- createDataFrame(sqlContext, mtcars) expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(sqlContext, list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) }) test_that("create DataFrame with different data types", { From 6ac2a75a2dc1d323b5f938b51251ae4cc0c2ed9b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 29 Nov 2015 14:13:11 -0800 Subject: [PATCH 0960/1168] [SPARK-12024][SQL] More efficient multi-column counting. In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell Closes #10015 from hvanhovell/SPARK-12024. --- .../expressions/aggregate/Count.scala | 21 ++-------- .../expressions/conditionalExpressions.scala | 27 ------------- .../sql/catalyst/optimizer/Optimizer.scala | 14 ++++--- .../ConditionalExpressionSuite.scala | 14 ------- .../spark/sql/execution/aggregate/utils.scala | 39 +++++++++---------- .../spark/sql/expressions/WindowSpec.scala | 4 +- 6 files changed, 33 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 09a1da9200df0..441f52ab5ca58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false @@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = LongType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) private lazy val count = AttributeReference("count", LongType)() @@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) ) override lazy val mergeExpressions = Seq( @@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate { } object Count { - def apply(children: Seq[Expression]): Count = { - // This is used to deal with COUNT DISTINCT. When we have multiple - // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). - // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any - // null in the arguments, we will not count that row. So, we use DropAnyNull at here - // to return a null when any field of the created STRUCT is null. - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - Count(child) - } + def apply(child: Expression): Count = Count(child :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 694a2a7c54a90..40b1eec63e551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } -/** Operator that drops a row when it contains any nulls. */ -case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StructType) - - protected override def nullSafeEval(input: Any): InternalRow = { - val row = input.asInstanceOf[InternalRow] - if (row.anyNull) { - null - } else { - row - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.anyNull()) { - ${ev.isNull} = true; - } else { - ${ev.value} = $eval; - } - """ - }) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2901d8f2efddf..06d14fcf8b9c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index c1e3c17b87102..0df673bb9fa02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } - - test("function dropAnyNull") { - val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) - val a = create_row("a", "q") - val nullStr: String = null - checkEvaluation(drop, a, a) - checkEvaluation(drop, null, create_row("b", nullStr)) - checkEvaluation(drop, null, create_row(nullStr, nullStr)) - - val row = 'r.struct( - StructField("a", StringType, false), - StructField("b", StringType, true)).at(0) - checkEvaluation(DropAnyNull(row), null, create_row(null)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index a70e41436c7aa..76b938cdb694e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -146,20 +146,16 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expression. + // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpression: Expression = { - val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - assert(allDistinctColumnExpressions.length == 1) - allDistinctColumnExpressions.head - } - val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctColumnExpressions = distinctColumnExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. @@ -170,10 +166,11 @@ object Utils { // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. - val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression + val partialAggregateGroupingExpressions = + groupingExpressions ++ namedDistinctColumnExpressions val partialAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( @@ -208,28 +205,28 @@ object Utils { partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } @@ -244,14 +241,16 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + val distinctColumnAttributeLookup = + distinctColumnExpressions.zip(distinctColumnAttributes).toMap val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction] + val rewrittenAggregateFunction = aggregateFunction + .transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, @@ -270,7 +269,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } else { @@ -281,7 +280,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index fc873c04f88f0..893e800a61438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -152,8 +152,8 @@ class WindowSpec private[sql]( case Sum(child) => WindowExpression( UnresolvedWindowFunction("sum", child :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), + case Count(children) => WindowExpression( + UnresolvedWindowFunction("count", children), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case First(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF first_value From ce0edf9889e3770bca30d90f80e938b25499442a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 29 Nov 2015 19:02:15 -0800 Subject: [PATCH 0961/1168] [SPARK-12039] [SQL] Ignore HiveSparkSubmitSuite's "SPARK-9757 Persist Parquet relation with decimal column". https://issues.apache.org/jira/browse/SPARK-12039 Since it is pretty flaky in hadoop 1 tests, we can disable it while we are investigating the cause. Author: Yin Huai Closes #10035 from yhuai/SPARK-12039-ignore. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 24a3afee148c5..92962193311d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -101,7 +101,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-9757 Persist Parquet relation with decimal column") { + ignore("SPARK-9757 Persist Parquet relation with decimal column") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SPARK_9757.getClass.getName.stripSuffix("$"), From ec7e453df2561b2590753aeb663ab652b10f1df3 Mon Sep 17 00:00:00 2001 From: toddwan Date: Mon, 30 Nov 2015 09:26:29 +0000 Subject: [PATCH 0962/1168] [SPARK-11859][MESOS] SparkContext accepts invalid Master URLs in the form zk://host:port for a multi-master Mesos cluster using ZooKeeper * According to below doc and validation logic in [SparkSubmit.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L231), master URL for a mesos cluster should always start with `mesos://` http://spark.apache.org/docs/latest/running-on-mesos.html `The Master URLs for Mesos are in the form mesos://host:5050 for a single-master Mesos cluster, or mesos://zk://host:2181 for a multi-master Mesos cluster using ZooKeeper.` * However, [SparkContext.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/SparkContext.scala#L2749) fails the validation and can receive master URL in the form `zk://host:port` * For the master URLs in the form `zk:host:port`, the valid form should be `mesos://zk://host:port` * This PR restrict the validation in `SparkContext.scala`, and now only mesos master URLs prefixed with `mesos://` can be accepted. * This PR also updated corresponding unit test. Author: toddwan Closes #9886 from toddwan/S11859. --- .../scala/org/apache/spark/SparkContext.scala | 16 ++++++++++------ .../SparkContextSchedulerCreationSuite.scala | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b030d3c71dc20..8a62b71c3fa68 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2708,15 +2708,14 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case mesosUrl @ MESOS_REGEX(_) => + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) - val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) + new CoarseMesosSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) } else { - new MesosSchedulerBackend(scheduler, sc, url) + new MesosSchedulerBackend(scheduler, sc, mesosUrl) } scheduler.initialize(backend) (backend, scheduler) @@ -2727,6 +2726,11 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) + case zkUrl if zkUrl.startsWith("zk://") => + logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + + "in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.") + createTaskScheduler(sc, "mesos://" + zkUrl) + case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") } @@ -2745,8 +2749,8 @@ private object SparkMasterRegex { val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Mesos cluster by mesos:// or mesos://zk:// url + val MESOS_REGEX = """mesos://(.*)""".r // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index e5a14a69ef05f..d18e0782c0392 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -175,6 +175,11 @@ class SparkContextSchedulerCreationSuite } test("mesos with zookeeper") { + testMesos("mesos://zk://localhost:1234,localhost:2345", + classOf[MesosSchedulerBackend], coarse = false) + } + + test("mesos with zookeeper and Master URL starting with zk://") { testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } From 5b38164dc65b67b32aa044d9535f9ecb21239afd Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 09:30:58 +0000 Subject: [PATCH 0963/1168] [MINOR][BUILD] Changed the comment to reflect the plugin project is there to support SBT pom reader only. Author: Prashant Sharma Closes #10012 from ScrapCodes/minor-build-comment. --- project/project/SparkPluginBuild.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 471d00bd8223f..cbb88dc7dd1dd 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -19,9 +19,8 @@ import sbt._ import sbt.Keys._ /** - * This plugin project is there to define new scala style rules for spark. This is - * a plugin project so that this gets compiled first and is put on the classpath and - * becomes available for scalastyle sbt plugin. + * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is + * a plugin project so that this gets compiled first and is available on the classpath for SBT build. */ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) From 68ff33c17a21b2e652aa0309ea54ad118fb4f885 Mon Sep 17 00:00:00 2001 From: Wieland Hoffmann Date: Mon, 30 Nov 2015 09:32:48 +0000 Subject: [PATCH 0964/1168] [DOC] Explicitly state that top maintains the order of elements Top is implemented in terms of takeOrdered, which already maintains the order, so top should, too. Author: Wieland Hoffmann Closes #10013 from mineo/top-order. --- .../main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 871be0b1f39ea..1e9d4f1803a81 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -556,7 +556,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by - * the specified Comparator[T]. + * the specified Comparator[T] and maintains the order. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -567,7 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD using the - * natural ordering for T. + * natural ordering for T and maintains the order. * @param num k, the number of top elements to return * @return an array of top elements */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aeb5eeaad32c..8b3731d935788 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1327,7 +1327,8 @@ abstract class RDD[T: ClassTag]( /** * Returns the top k (largest) elements from this RDD as defined by the specified - * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: + * implicit Ordering[T] and maintains the ordering. This does the opposite of + * [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) * // returns Array(12) From abe08fb468ed7a0fff5de0a0d4cfff2bb9cf0072 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 10:11:27 +0000 Subject: [PATCH 0965/1168] [SPARK-12023][BUILD] Fix warnings while packaging spark with maven. this is a trivial fix, discussed [here](http://stackoverflow.com/questions/28500401/maven-assembly-plugin-warning-the-assembly-descriptor-contains-a-filesystem-roo/). Author: Prashant Sharma Closes #10014 from ScrapCodes/assembly-warning. --- assembly/src/main/assembly/assembly.xml | 8 ++++---- external/mqtt/src/main/assembly/assembly.xml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 711156337b7c3..009d4b92f406c 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -32,7 +32,7 @@ ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/org/apache/spark/ui/static + ui-resources/org/apache/spark/ui/static **/* @@ -41,7 +41,7 @@ ${project.parent.basedir}/sbin/ - /sbin + sbin **/* @@ -50,7 +50,7 @@ ${project.parent.basedir}/bin/ - /bin + bin **/* @@ -59,7 +59,7 @@ ${project.parent.basedir}/assembly/target/${spark.jar.dir} - / + ${spark.jar.basename} diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml index ecab5b360eb3e..c110b01b34e10 100644 --- a/external/mqtt/src/main/assembly/assembly.xml +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -24,7 +24,7 @@ ${project.build.directory}/scala-${scala.binary.version}/test-classes - / + From 919153c3e30348ffdad2e37686c11b5b745f4249 Mon Sep 17 00:00:00 2001 From: CK50 Date: Mon, 30 Nov 2015 20:08:49 +0800 Subject: [PATCH 0966/1168] [SPARK-11989][SQL] Only use commit in JDBC data source if the underlying database supports transactions Fixes [SPARK-11989](https://issues.apache.org/jira/browse/SPARK-11989) Author: CK50 Author: Christian Kurz Closes #9973 from CK50/branch-1.6_non-transactional. (cherry picked from commit a589736a1b237ef2f3bd59fbaeefe143ddcc8f4e) Signed-off-by: Reynold Xin --- .../datasources/jdbc/JdbcUtils.scala | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7375a5c09123f..252f1cfd5d9c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -21,6 +21,7 @@ import java.sql.{Connection, PreparedStatement} import java.util.Properties import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} @@ -125,8 +126,19 @@ object JdbcUtils extends Logging { dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false + val supportsTransactions = try { + conn.getMetaData().supportsDataManipulationTransactionsOnly() || + conn.getMetaData().supportsDataDefinitionAndDataManipulationTransactions() + } catch { + case NonFatal(e) => + logWarning("Exception while detecting transaction support", e) + true + } + try { - conn.setAutoCommit(false) // Everything in the same db transaction. + if (supportsTransactions) { + conn.setAutoCommit(false) // Everything in the same db transaction. + } val stmt = insertStatement(conn, table, rddSchema) try { var rowCount = 0 @@ -175,14 +187,18 @@ object JdbcUtils extends Logging { } finally { stmt.close() } - conn.commit() + if (supportsTransactions) { + conn.commit() + } committed = true } finally { if (!committed) { // The stage must fail. We got here through an exception path, so // let the exception through unless rollback() or close() want to // tell the user about another problem. - conn.rollback() + if (supportsTransactions) { + conn.rollback() + } conn.close() } else { // The stage must succeed. We cannot propagate any exception close() might throw. From e54151bb6ae56dad6bca13546efe473d33659642 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 10:32:13 -0800 Subject: [PATCH 0967/1168] [SPARK-11700] [SQL] Remove thread local SQLContext in SparkPlan In 1.6, we introduce a public API to have a SQLContext for current thread, SparkPlan should use that. Author: Davies Liu Closes #9990 from davies/leak_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 10 +++++----- .../spark/sql/execution/QueryExecution.scala | 3 +-- .../org/apache/spark/sql/execution/SparkPlan.scala | 14 ++++---------- .../apache/spark/sql/MultiSQLContextsSuite.scala | 2 +- .../sql/execution/ExchangeCoordinatorSuite.scala | 2 +- .../sql/execution/RowFormatConvertersSuite.scala | 4 ++-- 6 files changed, 14 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c2ac5f6f11bf..8d2783952532a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,7 +26,6 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SparkException, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -45,9 +44,10 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -401,7 +401,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -417,7 +417,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -1334,7 +1334,7 @@ object SQLContext { activeContext.remove() } - private[sql] def getActiveContextOption(): Option[SQLContext] = { + private[sql] def getActive(): Option[SQLContext] = { Option(activeContext.get()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5da5aea17e25b..107570f9dbcc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -42,9 +42,8 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) - // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) sqlContext.planner.plan(optimizedPlan).next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 534a3bcb8364d..507641ff8263e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,21 +23,15 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType -object SparkPlan { - protected[sql] val currentContext = new ThreadLocal[SQLContext]() -} - /** * The base class for physical operators. */ @@ -49,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SparkPlan.currentContext.get() + protected[spark] final val sqlContext = SQLContext.getActive().get protected def sparkContext = sqlContext.sparkContext @@ -69,7 +63,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propogates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) super.makeCopy(newArgs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 34c5c68fd1c18..162c0b56c6e11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -27,7 +27,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var sparkConf: SparkConf = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b96d50a70b85c..180050bdac00f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -30,7 +30,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSQLContext: Option[SQLContext] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 6876ab0f02b10..13d68a103a225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -94,7 +94,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) From b6860b9bfc36fe8ebe550397be04c904f0d9c7d1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 11:54:18 -0800 Subject: [PATCH 0968/1168] [SPARK-11982] [SQL] improve performance of cartesian product This PR improve the performance of CartesianProduct by caching the result of right plan. After this patch, the query time of TPC-DS Q65 go down to 4 seconds from 28 minutes (420X faster). cc nongli Author: Davies Liu Closes #9969 from davies/improve_cartesian. --- .../unsafe/sort/UnsafeExternalSorter.java | 63 +++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 7 ++ .../execution/joins/CartesianProduct.scala | 76 +++++++++++++++++-- .../execution/metric/SQLMetricsSuite.scala | 2 +- 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 2e40312674737..5a97f4f11340c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import java.util.Queue; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -521,4 +522,66 @@ public long getKeyPrefix() { return upstream.getKeyPrefix(); } } + + /** + * Returns a iterator, which will return the rows in the order as inserted. + * + * It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getIterator() throws IOException { + if (spillWriters.isEmpty()) { + assert(inMemSorter != null); + return inMemSorter.getIterator(); + } else { + LinkedList queue = new LinkedList<>(); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + queue.add(spillWriter.getReader(blockManager)); + } + if (inMemSorter != null) { + queue.add(inMemSorter.getIterator()); + } + return new ChainedIterator(queue); + } + } + + /** + * Chain multiple UnsafeSorterIterator together as single one. + */ + class ChainedIterator extends UnsafeSorterIterator { + + private final Queue iterators; + private UnsafeSorterIterator current; + + public ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.iterators = iterators; + this.current = iterators.remove(); + } + + @Override + public boolean hasNext() { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + return current.hasNext(); + } + + @Override + public void loadNext() throws IOException { + current.loadNext(); + } + + @Override + public Object getBaseObject() { return current.getBaseObject(); } + + @Override + public long getBaseOffset() { return current.getBaseOffset(); } + + @Override + public int getRecordLength() { return current.getRecordLength(); } + + @Override + public long getKeyPrefix() { return current.getKeyPrefix(); } + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index dce1f15a2963c..c91e88f31bf9b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -226,4 +226,11 @@ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); return new SortedIterator(pos / 2); } + + /** + * Returns an iterator over record pointers in original order (inserted). + */ + public SortedIterator getIterator() { + return new SortedIterator(pos / 2); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index f467519b802a7..fa2bc7672131c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,16 +17,75 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + + +/** + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null. + val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, + iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), @@ -39,18 +98,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod val leftResults = left.execute().map { row => numLeftRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } val rightResults = right.execute().map { row => numRightRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } - leftResults.cartesian(rightResults).mapPartitionsInternal { iter => - val joinedRow = new JoinedRow + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) iter.map { r => numOutputRows += 1 - joinedRow(r._1, r._2) + joiner.join(r._1, r._2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index ebfa1eaf3e5bc..4f2cad19bfb6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -317,7 +317,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df, 1, Map( 1L -> ("CartesianProduct", Map( "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 12L, // right is read 6 times + "number of right rows" -> 4L, // right is read twice "number of output rows" -> 12L))) ) } From 710e445e906491d9be441d2d83fcbf9ec51e8900 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 30 Nov 2015 13:02:08 -0800 Subject: [PATCH 0969/1168] [MINOR][DOCS] fixed list display in ml-ensembles The list in ml-ensembles.md wasn't properly formatted and, as a result, was looking like this: ![old](http://i.imgur.com/2ZhELLR.png) This PR aims to make it look like this: ![new](http://i.imgur.com/0Xriwd2.png) Author: BenFradet Closes #10025 from BenFradet/ml-ensembles-doc. --- docs/ml-ensembles.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index f6c3c30d5334f..14fef76f260ff 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -20,6 +20,7 @@ Both use [MLlib decision trees](ml-decision-tree.html) as their base models. Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + * support for ML Pipelines * separation of classification vs. regression * use of DataFrame metadata to distinguish continuous and categorical features From 3de5d82d99d3753271768760fea30c298effe767 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 13:41:52 -0800 Subject: [PATCH 0970/1168] Revert "[SPARK-11206] Support SQL UI on the history server" This reverts commit cc243a079b1c039d6e7f0b410d1654d94a090e14 / PR #9297 I'm reverting this because it broke SQLListenerMemoryLeakSuite in the master Maven builds. See #9991 for a discussion of why this broke the tests. --- .rat-excludes | 1 - .../org/apache/spark/JavaSparkListener.java | 3 - .../apache/spark/SparkFirehoseListener.java | 4 - .../scheduler/EventLoggingListener.scala | 4 - .../spark/scheduler/SparkListener.scala | 24 +-- .../spark/scheduler/SparkListenerBus.scala | 1 - .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 - .../org/apache/spark/sql/SQLContext.scala | 18 +-- .../spark/sql/execution/SQLExecution.scala | 24 ++- .../spark/sql/execution/SparkPlanInfo.scala | 46 ------ .../sql/execution/metric/SQLMetricInfo.scala | 30 ---- .../sql/execution/metric/SQLMetrics.scala | 56 +++---- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++------------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 - 21 files changed, 135 insertions(+), 327 deletions(-) delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 7262c960ed6bb..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,5 +82,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 23bc9a2e81727..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,7 +82,4 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } - @Override - public void onOtherEvent(SparkListenerEvent event) { } - } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index e6b24afd88ad4..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,8 +118,4 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132e..000a021a528cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,10 +207,6 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } - override def onOtherEvent(event: SparkListenerEvent): Unit = { - logEvent(event, flushLogger = true) - } - /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 075a7f13172de..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,19 +22,15 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import com.fasterxml.jackson.annotation.JsonTypeInfo - -import org.apache.spark.{Logging, SparkConf, TaskEndReason} +import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} -import org.apache.spark.ui.SparkUI @DeveloperApi -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent +sealed trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -134,17 +130,6 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent -/** - * Interface for creating history listeners defined in other modules like SQL, which are used to - * rebuild the history UI. - */ -private[spark] trait SparkHistoryListenerFactory { - /** - * Create listeners used to rebuild the history UI. - */ - def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] -} - /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -238,11 +223,6 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } - - /** - * Called when other events like SQL-specific events are posted. - */ - def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 95722a07144ec..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,7 +61,6 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata - case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 8da6884a38535..4608bce202ec8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,13 +17,10 @@ package org.apache.spark.ui -import java.util.{Date, ServiceLoader} - -import scala.collection.JavaConverters._ +import java.util.Date import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -157,16 +154,7 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI + create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7f5d713ec6505..c9beeb25e05af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,21 +19,19 @@ package org.apache.spark.util import java.util.{Properties, UUID} +import org.apache.spark.scheduler.cluster.ExecutorInfo + import scala.collection.JavaConverters._ import scala.collection.Map -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -56,8 +54,6 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats - private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) - /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -100,7 +96,6 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this - case _ => parse(mapper.writeValueAsString(event)) } } @@ -516,8 +511,6 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) - case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) - .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory deleted file mode 100644 index 507100be90967..0000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8d2783952532a..9cc65de19180a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,8 +1263,6 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() - @transient private val sqlListener = new AtomicReference[SQLListener]() - /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1309,10 +1307,6 @@ object SQLContext { Option(instantiatedContext.get()) } - private[sql] def clearSqlListener(): Unit = { - sqlListener.set(null) - } - /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1361,13 +1355,9 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - if (sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - sqlListener.get() + val listener = new SQLListener(sc.conf) + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + listener } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 34971986261c2..1422e15549c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,8 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, - SparkListenerSQLExecutionEnd} +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -46,14 +45,25 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + sqlContext.listener.onExecutionStart( + executionId, + callSite.shortForm, + callSite.longForm, + queryExecution.toString, + SparkPlanGraph(queryExecution.executedPlan), + System.currentTimeMillis()) try { body } finally { - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. + // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new + // SQL event types to SparkListener since it's a public API, we cannot guarantee that. + // + // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. + // + // The worst case is onExecutionEnd may happen before onJobStart when the listener thread + // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. + sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala deleted file mode 100644 index 486ce34064e43..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.metric.SQLMetricInfo -import org.apache.spark.util.Utils - -/** - * :: DeveloperApi :: - * Stores information about a SQL SparkPlan. - */ -@DeveloperApi -class SparkPlanInfo( - val nodeName: String, - val simpleString: String, - val children: Seq[SparkPlanInfo], - val metrics: Seq[SQLMetricInfo]) - -private[sql] object SparkPlanInfo { - - def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - new SQLMetricInfo(metric.name.getOrElse(key), metric.id, - Utils.getFormattedClassName(metric.param)) - } - val children = plan.children.map(fromSparkPlan) - - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala deleted file mode 100644 index 2708219ad3485..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.metric - -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * Stores information about a SQL Metric. - */ -@DeveloperApi -class SQLMetricInfo( - val name: String, - val accumulatorId: Long, - val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6c0f6f8a52dc5..1c253e3942e95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,39 +104,21 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } -private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) - -private object StaticsLongSQLMetricParam extends LongSQLMetricParam( - (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - }, -1L) - private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - param: LongSQLMetricParam): LongSQLMetric = { + stringValue: Seq[Long] => String, + initialValue: Long): LongSQLMetric = { + val param = new LongSQLMetricParam(stringValue, initialValue) val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, LongSQLMetricParam) + createLongMetric(sc, name, _.sum.toString, 0L) } /** @@ -144,25 +126,31 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { + val stringValue = (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) - } - - def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { - val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) - val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) - val metricParam = metricParamName match { - case `longSQLMetricParam` => LongSQLMetricParam - case `staticsSQLMetricParam` => StaticsLongSQLMetricParam - } - metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] + createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) + val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index c74ad40406992..e74d6fb396e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} + +import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index e19a1e3e5851f..5a072de400b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,34 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} -import org.apache.spark.ui.SparkUI - -@DeveloperApi -case class SparkListenerSQLExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - sparkPlanInfo: SparkPlanInfo, - time: Long) - extends SparkListenerEvent - -@DeveloperApi -case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) - extends SparkListenerEvent - -private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { - - override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { - List(new SQLHistoryListener(conf, sparkUI)) - } -} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -141,8 +118,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), - finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) } } @@ -164,7 +140,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulatorUpdates(), + taskEnd.taskMetrics, finishTask = true) } @@ -172,12 +148,15 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - protected def updateTaskAccumulatorValues( + private def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Map[Long, Any], + metrics: TaskMetrics, finishTask: Boolean): Unit = { + if (metrics == null) { + return + } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -195,9 +174,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -206,7 +185,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, accumulatorUpdates) + attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) } } case None => @@ -214,40 +193,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) => - val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - val executionUIData = new SQLExecutionUIData( - executionId, - description, - details, - physicalPlanDescription, - physicalPlanGraph, - sqlPlanMetrics.toMap, - time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. - } + def onExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + physicalPlanGraph: SparkPlanGraph, + time: Long): Unit = { + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + + val executionUIData = new SQLExecutionUIData(executionId, description, details, + physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + } + + def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. } } - case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -312,38 +289,6 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } -private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) - extends SQLListener(conf) { - - private var sqlTabAttached = false - - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { acc => - (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) - }.toMap, - finishTask = true) - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case _: SparkListenerSQLExecutionStart => - if (!sqlTabAttached) { - new SQLTab(this, sparkUI) - sqlTabAttached = true - } - super.onOtherEvent(event) - case _ => super.onOtherEvent(event) - } -} - /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 4f50b2ecdc8f8..9c27944d42fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.ui +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, "SQL") with Logging { + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI @@ -33,5 +35,13 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL$nextId" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6d..f1fce5478a3fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { + def apply(plan: SparkPlan): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - planInfo: SparkPlanInfo, + plan: SparkPlan, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) nodes += node - val childrenNodes = planInfo.children.map( + val childrenNodes = plan.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6b..82867ab4967bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,6 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -83,8 +82,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).nodes.filter { node => + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index f93d081d0c30e..c15aac775096c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,8 +82,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) - .nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -91,13 +90,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) val executionUIData = listener.executionIdToData(0) @@ -207,8 +206,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -221,20 +219,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -251,13 +248,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -274,8 +271,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -292,20 +288,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787c..963d10eed62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,7 +42,6 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { - SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From fdfd9a183e81287130032fa93f70e9c7c53091c5 Mon Sep 17 00:00:00 2001 From: Teng Qiu Date: Tue, 1 Dec 2015 07:27:32 +0900 Subject: [PATCH 0971/1168] [SPARK-12053][CORE] EventLoggingListener.getLogPath needs 4 parameters ```EventLoggingListener.getLogPath``` needs 4 input arguments: https://github.com/apache/spark/blob/v1.6.0-preview2/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L276-L280 the 3rd parameter should be appAttemptId, 4th parameter is codec... Author: Teng Qiu Closes #10044 from chutium/SPARK-12053. --- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9952c97dbdffc..1355e1ad1b523 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -934,7 +934,7 @@ private[deploy] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) From db5f114b1558b3b82208bf34c8b9e7ab063c75ee Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 30 Nov 2015 14:56:51 -0800 Subject: [PATCH 0972/1168] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Original PR is reverted due to document build error. https://github.com/apache/spark/pull/9722 mengxr feynmanliang yinxusen Sorry for the troubling. Author: Yuhao Yang Closes #9974 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 31 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 97 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 0000000000000..cfefb5dfbde9e --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,31 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
    + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a17..6f35b30c3d4df 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %}
    - + \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec4..54e35fcbb15af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 0000000000000..3a5d3237c85f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; +// $example on$ +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + // $example on$ + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } + // $example off$ +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 0000000000000..419ce3d87a6ac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. + * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 46d4bdb8d7a0935ebf76b1d383ddc5a98dd10c01 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 30 Nov 2015 15:01:08 -0800 Subject: [PATCH 0973/1168] [SPARK-11975][ML] Remove duplicate mllib example (DT/RF/GBT in Java/Python) Remove duplicate mllib example (DT/RF/GBT in Java/Python). Since we have tutorial code for DT/RF/GBT classification/regression in Scala/Java/Python and example applications for DT/RF/GBT in Scala, so we mark these as duplicated and remove them. mengxr Author: Yanbo Liang Closes #9954 from yanboliang/SPARK-11975. --- .../examples/mllib/JavaDecisionTree.java | 116 -------------- .../mllib/JavaGradientBoostedTreesRunner.java | 126 --------------- .../mllib/JavaRandomForestExample.java | 139 ----------------- .../main/python/mllib/decision_tree_runner.py | 144 ------------------ .../python/mllib/gradient_boosted_trees.py | 77 ---------- .../python/mllib/random_forest_example.py | 90 ----------- 6 files changed, 692 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java delete mode 100755 examples/src/main/python/mllib/decision_tree_runner.py delete mode 100644 examples/src/main/python/mllib/gradient_boosted_trees.py delete mode 100755 examples/src/main/python/mllib/random_forest_example.py diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java deleted file mode 100644 index 1f82e3f4cb18e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.HashMap; - -import scala.Tuple2; - -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -/** - * Classification and regression using decision trees. - */ -public final class JavaDecisionTree { - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - if (args.length == 1) { - datapath = args[0]; - } else if (args.length > 1) { - System.err.println("Usage: JavaDecisionTree "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - - // Set parameters. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - - // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - - // Train a DecisionTree model for regression. - impurity = "variance"; - final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD regressorPredictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(regressionModel.predict(p.features()), p.label()); - } - }); - Double trainMSE = - regressorPredictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + regressionModel); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java deleted file mode 100644 index a1844d5d07ad4..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. - */ -public final class JavaGradientBoostedTreesRunner { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTreesRunner " + - " "); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.treeStrategy().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.treeStrategy().setNumClasses(numClasses); - - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java deleted file mode 100644 index 89a4e092a5af7..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import java.util.HashMap; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -public final class JavaRandomForestExample { - - /** - * Note: This example illustrates binary classification. - * For information on multiclass classification, please refer to the JavaDecisionTree.java - * example. - */ - private static void testClassification(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "gini"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); - System.out.println("Test Error: " + testErr); - System.out.println("Learned classification forest model:\n" + model.toDebugString()); - } - - private static void testRegression(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); - System.out.println("Test Mean Squared Error: " + testMSE); - System.out.println("Learned regression forest model:\n" + model.toDebugString()); - } - - public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - // Load and parse the data file. - String datapath = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); - // Split the data into training and test sets (30% held out for testing) - JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); - JavaRDD trainingData = splits[0]; - JavaRDD testData = splits[1]; - - System.out.println("\nRunning example of classification using RandomForest\n"); - testClassification(trainingData, testData); - - System.out.println("\nRunning example of regression using RandomForest\n"); - testRegression(trainingData, testData); - sc.stop(); - } -} diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py deleted file mode 100755 index 513ed8fd51450..0000000000000 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ /dev/null @@ -1,144 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Decision tree classification and regression using MLlib. - -This example requires NumPy (http://www.numpy.org/). -""" -from __future__ import print_function - -import numpy -import os -import sys - -from operator import add - -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree -from pyspark.mllib.util import MLUtils - - -def getAccuracy(dtModel, data): - """ - Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainCorrect / (0.0 + data.count()) - - -def getMSE(dtModel, data): - """ - Return mean squared error (MSE) of DecisionTreeModel on the given - RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainMSE / (0.0 + data.count()) - - -def reindexClassLabels(data): - """ - Re-index class labels in a dataset to the range {0,...,numClasses-1}. - If all labels in that range already appear at least once, - then the returned RDD is the same one (without a mapping). - Note: If a label simply does not appear in the data, - the index will not include it. - Be aware of this when reindexing subsampled data. - :param data: RDD of LabeledPoint where labels are integer values - denoting labels for a classification problem. - :return: Pair (reindexedData, origToNewLabels) where - reindexedData is an RDD of LabeledPoint with labels in - the range {0,...,numClasses-1}, and - origToNewLabels is a dictionary mapping original labels - to new labels. - """ - # classCounts: class --> # examples in class - classCounts = data.map(lambda x: x.label).countByValue() - numExamples = sum(classCounts.values()) - sortedClasses = sorted(classCounts.keys()) - numClasses = len(classCounts) - # origToNewLabels: class --> index in 0,...,numClasses-1 - if (numClasses < 2): - print("Dataset for classification should have at least 2 classes." - " The given dataset had only %d classes." % numClasses, file=sys.stderr) - exit(1) - origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - - print("numClasses = %d" % numClasses) - print("Per-class example fractions, counts:") - print("Class\tFrac\tCount") - for c in sortedClasses: - frac = classCounts[c] / (numExamples + 0.0) - print("%g\t%g\t%d" % (c, frac, classCounts[c])) - - if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): - return (data, origToNewLabels) - else: - reindexedData = \ - data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) - return (reindexedData, origToNewLabels) - - -def usage(): - print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) - exit(1) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - usage() - sc = SparkContext(appName="PythonDT") - - # Load data. - dataPath = 'data/mllib/sample_libsvm_data.txt' - if len(sys.argv) == 2: - dataPath = sys.argv[1] - if not os.path.isfile(dataPath): - sc.stop() - usage() - points = MLUtils.loadLibSVMFile(sc, dataPath) - - # Re-index class labels if needed. - (reindexedData, origToNewLabels) = reindexClassLabels(points) - numClasses = len(origToNewLabels) - - # Train a classifier. - categoricalFeaturesInfo = {} # no categorical features - model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, - categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. - print("Trained DecisionTree for classification:") - print(" Model numNodes: %d" % model.numNodes()) - print(" Model depth: %d" % model.depth()) - print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) - if model.numNodes() < 20: - print(model.toDebugString()) - else: - print(model) - - sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py deleted file mode 100644 index 781bd61c9d2b5..0000000000000 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Gradient boosted Trees classification and regression using MLlib. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import GradientBoostedTrees -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification ensemble model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ - / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression ensemble model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGradientBoostedTrees") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using GradientBoostedTrees\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using GradientBoostedTrees\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py deleted file mode 100755 index 4cfdad868c66e..0000000000000 --- a/examples/src/main/python/mllib/random_forest_example.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Random Forest classification and regression using MLlib. - -Note: This example illustrates binary classification. - For information on multiclass classification, please refer to the decision_tree_runner.py - example. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import RandomForest -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainClassifier(trainingData, numClasses=2, - categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification forest model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ - .sum() / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression forest model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using RandomForest\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using RandomForest\n') - testRegression(trainingData, testData) - - sc.stop() From ca56fe88b5373f6d03808e40b67378ded5c15923 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 30 Nov 2015 15:38:44 -0800 Subject: [PATCH 0974/1168] [SPARK-11960][MLLIB][DOC] User guide for streaming tests CC jkbradley mengxr josepablocam Author: Feynman Liang Closes #10005 from feynmanliang/streaming-test-user-guide. --- docs/mllib-guide.md | 1 + docs/mllib-statistics.md | 25 +++++++++++++++++++ .../examples/mllib/StreamingTestExample.scala | 2 ++ 3 files changed, 28 insertions(+) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15af..43772adcf26e1 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -34,6 +34,7 @@ We list major functionality from both below, with links to detailed guides. * [correlations](mllib-statistics.html#correlations) * [stratified sampling](mllib-statistics.html#stratified-sampling) * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [streaming significance testing](mllib-statistics.html#streaming-significance-testing) * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ade5b0768aefe..de209f68e19ca 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -521,6 +521,31 @@ print(testResult) # summary of the test including the p-value, test statistic, +### Streaming Significance Testing +MLlib provides online implementations of some tests to support use cases +like A/B testing. These tests may be performed on a Spark Streaming +`DStream[(Boolean,Double)]` where the first element of each tuple +indicates control group (`false`) or treatment group (`true`) and the +second element is the value of an observation. + +Streaming significance testing supports the following parameters: + +* `peacePeriod` - The number of initial data points from the stream to +ignore, used to mitigate novelty effects. +* `windowSize` - The number of past batches to perform hypothesis +testing over. Setting to `0` will perform cumulative processing using +all prior batches. + + +
    +
    +[`StreamingTest`](api/scala/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example scala/org/apache/spark/examples/mllib/StreamingTestExample.scala %} +
    +
    + ## Random data generation diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index ab29f90254d34..b6677c6476639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -64,6 +64,7 @@ object StreamingTestExample { dir.toString }) + // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { case Array(label, value) => (label.toBoolean, value.toDouble) }) @@ -75,6 +76,7 @@ object StreamingTestExample { val out = streamingTest.registerStream(data) out.print() + // $example off$ // Stop processing if test becomes significant or we time out var timeoutCounter = numBatchesTimeout From 7d9a876069cd48fff789b7d53346d7ad540c4e77 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 15:42:10 -0800 Subject: [PATCH 0975/1168] fix Maven build --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 507641ff8263e..a78177751c9dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -43,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SQLContext.getActive().get + protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null) protected def sparkContext = sqlContext.sparkContext From 95aa0ee03a94926665cfd0d9e8c4e63c45e15b8c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 30 Nov 2015 16:31:59 -0800 Subject: [PATCH 0976/1168] [SPARK-12058][HOTFIX] Disable KinesisStreamTests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit KinesisStreamTests in test.py is broken because of #9403. See https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46896/testReport/(root)/KinesisStreamTests/test_kinesis_stream/ Because Streaming Python didn’t work when merging https://github.com/apache/spark/pull/9403, the PR build didn’t report the Python test failure actually. This PR just disabled the test to unblock #10039 Author: Shixiong Zhu Closes #10047 from zsxwing/disable-python-kinesis-test. --- python/pyspark/streaming/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d380d697bc51c..a647e6bf39581 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1409,6 +1409,7 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") + @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From 4445bf0b051d3e63f0fc3342a7236956d3140569 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 16:37:27 -0800 Subject: [PATCH 0977/1168] [SPARK-12000] Fix API doc generation issues This pull request fixes multiple issues with API doc generation. - Modify the Jekyll plugin so that the entire doc build fails if API docs cannot be generated. This will make it easy to detect when the doc build breaks, since this will now trigger Jenkins failures. - Change how we handle the `-target` compiler option flag in order to fix `javadoc` generation. - Incorporate doc changes from thunterdb (in #10048). Closes #10048. Author: Josh Rosen Author: Timothy Hunter Closes #10049 from JoshRosen/fix-doc-build. --- docs/_plugins/copy_api_dirs.rb | 6 +++--- .../apache/spark/network/client/StreamCallback.java | 4 ++-- .../org/apache/spark/network/server/RpcHandler.java | 2 +- project/SparkBuild.scala | 11 ++++++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 01718d98dffe0..f2f3e2e653149 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -27,7 +27,7 @@ cd("..") puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl clean compile unidoc` + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - puts `make html` + system(make html) || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") @@ -131,7 +131,7 @@ # Build SparkR API docs puts "Moving to R directory and building roxygen docs." cd("R") - puts `./create-docs.sh` + system("./create-docs.sh") || raise("R doc generation failed") puts "Moving back into home dir." cd("../") diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java index 093fada320cc3..51d34cac6e636 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -21,8 +21,8 @@ import java.nio.ByteBuffer; /** - * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link onComplete()} will be + * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be * called. *

    * The network library guarantees that a single thread will call these methods at a time, but diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 65109ddfe13b9..1a11f7b3820c6 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -55,7 +55,7 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f575f0012d59e..63290d8a666e6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -160,7 +160,12 @@ object SparkBuild extends PomBuild { javacOptions in Compile ++= Seq( "-encoding", "UTF-8", - "-source", javacJVMVersion.value, + "-source", javacJVMVersion.value + ), + // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't + // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for + // additional discussion and explanation. + javacOptions in (Compile, compile) ++= Seq( "-target", javacJVMVersion.value ), @@ -547,9 +552,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. From 776d6cd2544a916964f602e7e09adcaaac8aa54f Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 30 Nov 2015 17:18:44 -0800 Subject: [PATCH 0978/1168] [SPARK-12035] Add more debug information in include_example tag of Jekyll https://issues.apache.org/jira/browse/SPARK-12035 When we debuging lots of example code files, like in https://github.com/apache/spark/pull/10002, it's hard to know which file causes errors due to limited information in `include_example.rb`. With their filenames, we can locate bugs easily. Author: Xusen Yin Closes #10026 from yinxusen/SPARK-12035. --- docs/_plugins/include_example.rb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 564c86680f68e..f7485826a762d 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -75,10 +75,10 @@ def select_lines(code) .select { |l, i| l.include? "$example off$" } .map { |l, i| i } - raise "Start indices amount is not equal to end indices amount, please check the code." \ + raise "Start indices amount is not equal to end indices amount, see #{@file}." \ unless startIndices.size == endIndices.size - raise "No code is selected by include_example, please check the code." \ + raise "No code is selected by include_example, see #{@file}." \ if startIndices.size == 0 # Select and join code blocks together, with a space line between each of two continuous @@ -86,8 +86,10 @@ def select_lines(code) lastIndex = -1 result = "" startIndices.zip(endIndices).each do |start, endline| - raise "Overlapping between two example code blocks are not allowed." if start <= lastIndex - raise "$example on$ should not be in the same line with $example off$." if start == endline + raise "Overlapping between two example code blocks are not allowed, see #{@file}." \ + if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$, see #{@file}." \ + if start == endline lastIndex = endline range = Range.new(start + 1, endline - 1) result += trim_codeblock(lines[range]).join From 378f13781e04f15076d6d2ffac5c1352f87a8040 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 30 Nov 2015 17:19:26 -0800 Subject: [PATCH 0979/1168] [SPARK-12037][CORE] initialize heartbeatReceiverRef before calling startDriverHeartbeat https://issues.apache.org/jira/browse/SPARK-12037 a simple fix by changing the order of the statements Author: CodingCat Closes #10032 from CodingCat/SPARK-12037. --- .../main/scala/org/apache/spark/executor/Executor.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6154f06e3ac11..7b68dfe5ad06e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -109,6 +109,10 @@ private[spark] class Executor( // Executor for the heartbeat task. private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + startDriverHeartbeater() def launchTask( @@ -411,9 +415,6 @@ private[spark] class Executor( } } - private val heartbeatReceiverRef = - RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { // list of (task id, metrics) to send back to the driver From 24bfd58b03d61edc5c20b05817f7ee6c32991dbb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 30 Nov 2015 17:22:05 -0800 Subject: [PATCH 0980/1168] [SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer. This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes: - The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses. - The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases. - Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those. Author: Marcelo Vanzin Closes #9987 from vanzin/SPARK-12007. --- .../mesos/MesosExternalShuffleService.scala | 3 +- .../network/netty/NettyBlockRpcServer.scala | 8 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 16 +- .../org/apache/spark/rpc/netty/Outbox.scala | 9 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 +- .../network/client/RpcResponseCallback.java | 4 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportResponseHandler.java | 16 +- .../network/protocol/AbstractMessage.java | 54 +++++++ ...Body.java => AbstractResponseMessage.java} | 16 +- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 8 +- .../spark/network/protocol/Message.java | 11 +- .../network/protocol/MessageEncoder.java | 29 ++-- .../spark/network/protocol/OneWayMessage.java | 33 ++-- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 34 +++-- .../spark/network/protocol/RpcResponse.java | 39 +++-- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 19 +-- .../network/sasl/SaslClientBootstrap.java | 12 +- .../spark/network/sasl/SaslMessage.java | 31 ++-- .../spark/network/sasl/SaslRpcHandler.java | 26 +++- .../spark/network/server/MessageHandler.java | 2 +- .../spark/network/server/NoOpRpcHandler.java | 8 +- .../spark/network/server/RpcHandler.java | 8 +- .../server/TransportChannelHandler.java | 2 +- .../server/TransportRequestHandler.java | 15 +- .../apache/spark/network/util/JavaUtils.java | 48 ++++-- .../network/util/TransportFrameDecoder.java | 142 ++++++++++++------ .../network/ChunkFetchIntegrationSuite.java | 5 +- .../apache/spark/network/ProtocolSuite.java | 10 +- .../RequestTimeoutIntegrationSuite.java | 51 ++++--- .../spark/network/RpcIntegrationSuite.java | 26 ++-- .../org/apache/spark/network/StreamSuite.java | 5 +- .../TransportResponseHandlerSuite.java | 24 +-- .../spark/network/sasl/SparkSaslSuite.java | 43 ++++-- .../util/TransportFrameDecoderSuite.java | 23 ++- .../shuffle/ExternalShuffleBlockHandler.java | 9 +- .../shuffle/ExternalShuffleClient.java | 3 +- .../shuffle/OneForOneBlockFetcher.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../protocol/BlockTransferMessage.java | 8 +- .../network/sasl/SaslIntegrationSuite.java | 18 ++- .../shuffle/BlockTransferMessagesSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 21 +-- .../shuffle/OneForOneBlockFetcherSuite.java | 8 +- 50 files changed, 589 insertions(+), 307 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/{ResponseWithBody.java => AbstractResponseMessage.java} (63%) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 12337a940a414..8ffcfc0878a42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.mesos import java.net.SocketAddress +import java.nio.ByteBuffer import scala.collection.mutable @@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo } } connectedApps(address) = appId - callback.onSuccess(new Array[Byte](0)) + callback.onSuccess(ByteBuffer.allocate(0)) case _ => super.handleMessage(message, client, callback) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 76968249fb625..df8c21fb837ed 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -47,9 +47,9 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { @@ -58,7 +58,7 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. @@ -66,7 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b0694e3c6c8af..82c16e855b0c0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c7d74fa1d9195..68c5f44145b0d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv( promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - java.util.Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](bytes) } } } @@ -557,7 +555,7 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte], + message: ByteBuffer, callback: RpcResponseCallback): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) @@ -565,12 +563,12 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte]): Unit = { + message: ByteBuffer): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 36fdd00bbc4c2..2316ebe347bb7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc.netty +import java.nio.ByteBuffer import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy @@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage { } -private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage with Logging { override def sendWith(client: TransportClient): Unit = { @@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb } private[netty] case class RpcOutboxMessage( - content: Array[Byte], + content: ByteBuffer, _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) + _onSuccess: (TransportClient, ByteBuffer) => Unit) extends OutboxMessage with RpcResponseCallback { private var client: TransportClient = _ @@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage( _onFailure(e) } - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { _onSuccess(client, response) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 323184cdd9b6e..ebd6f700710bd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress +import java.nio.ByteBuffer import io.netty.channel.Channel import org.mockito.Mockito._ @@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d795420..47e93f9846fa6 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,15 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { /** Successful serialized result from server. */ - void onSuccess(byte[] response); + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8a58e7b24585b..c49ca4d5ee925 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; @@ -212,7 +214,7 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. */ - public long sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -220,7 +222,7 @@ public long sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception { * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { result.set(response); } @@ -279,8 +281,8 @@ public void onFailure(Throwable e) { * * @param message The message to send. */ - public void send(byte[] message) { - channel.writeAndFlush(new OneWayMessage(message)); + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 4c15045363b84..23a8dba593442 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -136,7 +136,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -144,11 +144,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.body.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); - resp.body.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -166,10 +166,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 0000000000000..2924218c2f08b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 63% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java rename to network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java index 67be77e39f711..c362c92fc4f52 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -17,23 +17,15 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; /** - * Abstract class for response messages that contain a large data portion kept in a separate - * buffer. These messages are treated especially by MessageEncoder. + * Abstract class for response messages. */ -public abstract class ResponseWithBody implements ResponseMessage { - public final ManagedBuffer body; - public final boolean isBodyInFrame; +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { - protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); } public abstract ResponseMessage createFailureResponse(String error); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f0363830b61ac..7b28a9a969486 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 5a173af54f618..26d063feb5fe3 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index e6a7e9a8b4145..94c2ac9b20e43 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,7 +30,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess extends ResponseWithBody { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @@ -67,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body); + return Objects.hashCode(streamChunkId, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -83,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", body) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 39afd03db60ee..66f5b8b3a59c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,17 +19,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. */ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9); + OneWayMessage(9), User(-1); private final byte id; @@ -57,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 6cce97c807dc0..abca22347b783 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder { * data to 'out', in order to enable zero-copy transfer. */ @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { Object body = null; long bodyLength = 0; boolean isBodyInFrame = false; - // Detect ResponseWithBody messages and get the data buffer out of them. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ResponseWithBody) { - ResponseWithBody resp = (ResponseWithBody) in; + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { try { - bodyLength = resp.body.size(); - body = resp.body.convertToNetty(); - isBodyInFrame = resp.isBodyInFrame; + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - resp, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } return; } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index 95a0270be3da9..efe0470f35875 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -17,21 +17,21 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A RPC that does not expect a reply, which is handled by a remote * {@link org.apache.spark.network.server.RpcHandler}. */ -public final class OneWayMessage implements RequestMessage { - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; +public final class OneWayMessage extends AbstractMessage implements RequestMessage { - public OneWayMessage(byte[] message) { - this.message = message; + public OneWayMessage(ManagedBuffer body) { + super(body, true); } @Override @@ -39,29 +39,34 @@ public OneWayMessage(byte[] message) { @Override public int encodedLength() { - return Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; } @Override public void encode(ByteBuf buf) { - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static OneWayMessage decode(ByteBuf buf) { - byte[] message = Encoders.ByteArrays.decode(buf); - return new OneWayMessage(message); + // See comment in encodedLength(). + buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Arrays.hashCode(message); + return Objects.hashCode(body()); } @Override public boolean equals(Object other) { if (other instanceof OneWayMessage) { OneWayMessage o = (OneWayMessage) other; - return Arrays.equals(message, o.message); + return super.equals(o); } return false; } @@ -69,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 2dfc7876ba328..a76624ef5dc96 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 745039db742fa..96213794a8015 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -17,26 +17,25 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements RequestMessage { +public final class RpcRequest extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ public final long requestId; - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.message = message; } @Override @@ -44,31 +43,36 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); + // See comment in encodedLength(). + buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(message)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); + return requestId == o.requestId && super.equals(o); } return false; } @@ -77,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 1671cd444f039..bae866e14a1e1 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -17,49 +17,62 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { +public final class RpcResponse extends AbstractResponseMessage { public final long requestId; - public final byte[] response; - public RpcResponse(long requestId, byte[] response) { + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.response = response; } @Override public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); + // See comment in encodedLength(). + buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(response)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); + return requestId == o.requestId && super.equals(o); } return false; } @@ -68,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("response", response) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index e3dade2ebf905..26747ee55b4de 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -26,7 +26,7 @@ /** * Message indicating an error when transferring a stream. */ -public final class StreamFailure implements ResponseMessage { +public final class StreamFailure extends AbstractMessage implements ResponseMessage { public final String streamId; public final String error; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index 821e8f53884d7..35af5a84ba6bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -29,7 +29,7 @@ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before * the data can be streamed. */ -public final class StreamRequest implements RequestMessage { +public final class StreamRequest extends AbstractMessage implements RequestMessage { public final String streamId; public StreamRequest(String streamId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index ac5ab9a323a11..51b899930f721 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -30,15 +30,15 @@ * sender. The receiver is expected to set a temporary channel handler that will consume the * number of bytes this message says the stream has. */ -public final class StreamResponse extends ResponseWithBody { - public final String streamId; - public final long byteCount; +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } @Override public Type type() { return Type.StreamResponse; } @@ -68,7 +68,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId); + return Objects.hashCode(byteCount, streamId, body()); } @Override @@ -85,6 +85,7 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("byteCount", byteCount) + .add("body", body()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 69923769d44b4..68381037d6891 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,11 +73,12 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -88,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index cad76ab7aa54e..e52b526f09c77 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -18,38 +18,50 @@ package org.apache.spark.network.sasl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; -import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.protocol.AbstractMessage; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage implements Encodable { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; - public final byte[] payload; - public SaslMessage(String appId, byte[] payload) { + public SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + public SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); this.appId = appId; - this.payload = payload; } + @Override + public Type type() { return Type.User; } + @Override public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static SaslMessage decode(ByteBuf buf) { @@ -59,7 +71,8 @@ public static SaslMessage decode(ByteBuf buf) { } String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 830db94b890c5..c215bd9d15045 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. @@ -86,8 +96,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -109,7 +125,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecbd..3843406b27403 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,7 +26,7 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e864..6ed61da5c7eff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. */ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 1a11f7b3820c6..ee1c683699478 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +46,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - byte[] message, + ByteBuffer message, RpcResponseCallback callback); /** @@ -62,7 +64,7 @@ public abstract void receive( * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. */ - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { receive(client, message, ONE_WAY_CALLBACK); } @@ -79,7 +81,7 @@ private static class OneWayRpcCallback implements RpcResponseCallback { private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.warn("Response provided for one-way RPC."); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 3164e00679035..09435bcbab35e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -99,7 +99,7 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) { + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); } else { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index db18ea77d1073..c864d7ce16bd3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; @@ -26,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -143,10 +146,10 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -157,14 +160,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); } } private void processOneWayMessage(OneWayMessage req) { try { - rpcHandler.receive(reverseClient, req.message); + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7a..b3d8e0cd7cdcd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -164,32 +164,32 @@ private static boolean isSymlink(File file) throws IOException { */ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. @@ -205,10 +205,10 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } - + /** * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is + * internal use. If no suffix is provided a direct conversion of the provided default is * attempted. */ private static long parseByteString(String str, ByteUnit unit) { @@ -217,7 +217,7 @@ private static long parseByteString(String str, ByteUnit unit) { try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,14 +228,14 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { String timeError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + @@ -248,7 +248,7 @@ private static long parseByteString(String str, ByteUnit unit) { /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { @@ -264,7 +264,7 @@ public static long byteStringAsBytes(String str) { public static long byteStringAsKb(String str) { return parseByteString(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -284,4 +284,20 @@ public static long byteStringAsMb(String str) { public static long byteStringAsGb(String str) { return parseByteString(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. + */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 5889562dd9705..a466c729154aa 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -17,9 +17,13 @@ package org.apache.spark.network.util; +import java.util.Iterator; +import java.util.LinkedList; + import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { public static final String HANDLER_NAME = "frameDecoder"; private static final int LENGTH_SIZE = 8; private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - private CompositeByteBuf buffer; + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } - if (buffer == null) { - buffer = in.alloc().compositeBuffer(); - } - - buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); - - while (buffer.isReadable()) { - discardReadBytes(); - if (!feedInterceptor()) { + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. ByteBuf frame = decodeNext(); if (frame == null) { break; } - ctx.fireChannelRead(frame); } } - - discardReadBytes(); } - private void discardReadBytes() { - // If the buffer's been retained by downstream code, then make a copy of the remaining - // bytes into a new buffer. Otherwise, just discard stale components. - if (buffer.refCnt() > 1) { - CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } - if (buffer.readableBytes() > 0) { - ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); - spillBuf.writeBytes(buffer); - newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); } + return nextFrameSize; + } - buffer.release(); - buffer = newBuffer; - } else { - buffer.discardReadComponents(); + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; } private ByteBuf decodeNext() throws Exception { - if (buffer.readableBytes() < LENGTH_SIZE) { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { return null; } - int frameLen = (int) buffer.readLong() - LENGTH_SIZE; - if (buffer.readableBytes() < frameLen) { - buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); - return null; + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); } - Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); - Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); + // Otherwise, create a composite buffer. + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } - ByteBuf frame = buffer.readSlice(frameLen); - frame.retain(); return frame; } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (buffer != null) { - if (buffer.isReadable()) { - feedInterceptor(); - } - buffer.release(); + for (ByteBuf b : buffers) { + b.release(); } if (interceptor != null) { interceptor.channelInactive(); } + frameLenBuf.release(); super.channelInactive(ctx); } @@ -141,8 +199,8 @@ public void setInterceptor(Interceptor interceptor) { /** * @return Whether the interceptor is still active after processing the data. */ - private boolean feedInterceptor() throws Exception { - if (interceptor != null && !interceptor.handle(buffer)) { + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { interceptor = null; } return interceptor != null; diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 50a324e293386..70c849d60e0a6 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -107,7 +107,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 1aa20900ffe74..6c8dd742f4b64 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -82,10 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new byte[100])); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -94,8 +94,8 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 42955ef69235a..f9b5bf96d6215 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.junit.*; +import static org.junit.Assert.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -84,13 +85,16 @@ public void tearDown() { @Test public void timeoutInactiveRequests() throws Exception { final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -110,15 +114,15 @@ public StreamManager getStreamManager() { // First completes quickly (semaphore starts at 1). TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client.sendRpc(new byte[0], callback0); + client.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); - assert (callback0.success.length == response.length); + assertEquals(responseSize, callback0.successLength); } // Second times out after 2 seconds, with slack. Must be IOException. TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client.sendRpc(new byte[0], callback1); + client.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(4 * 1000); assert (callback1.failure != null); assert (callback1.failure instanceof IOException); @@ -131,13 +135,16 @@ public StreamManager getStreamManager() { @Test public void timeoutCleanlyClosesClient() throws Exception { final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -158,7 +165,7 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); + client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); assert (callback0.failure instanceof IOException); assert (!client0.isActive()); @@ -170,10 +177,10 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); + client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); } } @@ -191,7 +198,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -218,9 +228,10 @@ public StreamManager getStreamManager() { synchronized (callback0) { // not complete yet, but should complete soon - assert (callback0.success == null && callback0.failure == null); + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); + assertTrue(callback0.failure instanceof IOException); } synchronized (callback1) { @@ -235,13 +246,13 @@ public StreamManager getStreamManager() { */ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - byte[] success; + int successLength = -1; Throwable failure; @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { synchronized(this) { - success = response; + successLength = response.remaining(); this.notifyAll(); } } @@ -258,7 +269,7 @@ public void onFailure(Throwable e) { public void onSuccess(int chunkIndex, ManagedBuffer buffer) { synchronized(this) { try { - success = buffer.nioByteBuffer().array(); + successLength = buffer.nioByteBuffer().remaining(); this.notifyAll(); } catch (IOException e) { // weird diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 88fa2258bb794..9e9be98c140b7 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -26,7 +27,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -41,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -55,11 +56,14 @@ public static void setUp() throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -68,9 +72,8 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { - String msg = new String(message, Charsets.UTF_8); - oneWayMsgs.add(msg); + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); } @Override @@ -103,8 +106,9 @@ private RpcResult sendRPC(String ... commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -116,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -173,7 +177,7 @@ public void sendOneWayMessage() throws Exception { final String message = "no reply"; TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.send(message.getBytes(Charsets.UTF_8)); + client.send(JavaUtils.stringToBytes(message)); assertEquals(0, client.getHandler().numOutstandingRequests()); // Make sure the message arrives. diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 538f3efe8d6f2..9c49556927f0b 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -116,7 +116,10 @@ public ManagedBuffer openStream(String streamId) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 30144f4a9fc7a..128f7cba74350 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; + import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -27,6 +29,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; @@ -42,7 +45,7 @@ public class TransportResponseHandlerSuite { @Test - public void handleSuccessfulFetch() { + public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); @@ -56,7 +59,7 @@ public void handleSuccessfulFetch() { } @Test - public void handleFailedFetch() { + public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); @@ -69,7 +72,7 @@ public void handleFailedFetch() { } @Test - public void clearAllOutstandingRequests() { + public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); @@ -88,23 +91,24 @@ public void clearAllOutstandingRequests() { } @Test - public void handleSuccessfulRPC() { + public void handleSuccessfulRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); assertEquals(1, handler.numOutstandingRequests()); - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); assertEquals(0, handler.numOutstandingRequests()); } @Test - public void handleFailedRPC() { + public void handleFailedRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); @@ -119,7 +123,7 @@ public void handleFailedRPC() { } @Test - public void testActiveStreams() { + public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); TransportResponseHandler handler = new TransportResponseHandler(c); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index a6f180bc40c9a..751516b9d82a1 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -22,7 +22,7 @@ import java.io.File; import java.lang.reflect.Method; -import java.nio.charset.StandardCharsets; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -57,6 +57,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -123,39 +124,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); - cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); // There should be 2 terminated events; one for the client, one for the server. - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -325,8 +340,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 19475c21ffce9..d4de4a941d480 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -118,6 +118,27 @@ public Void answer(InvocationOnMock in) { } } + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + @Test(expected = IllegalArgumentException.class) public void testNegativeFrameSize() throws Exception { testInvalidFrame(-1); @@ -183,7 +204,7 @@ private void testInvalidFrame(long size) throws Exception { try { decoder.channelRead(ctx, frame); } finally { - frame.release(); + release(frame); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 3ddf5c3c39189..f22187a01db02 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; @@ -66,8 +67,8 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -85,13 +86,13 @@ protected void handleMessage( } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ef3a9dcc8711f..58ca87d9d3b13 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -139,7 +140,7 @@ public void registerWithShuffleServer( checkInit(); TransportClient client = clientFactory.createUnmanagedClient(host, port); try { - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } finally { client.close(); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147ee..1b2ddbf1ed917 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 @@ public void start() { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 7543b6be4f2a1..675820308bd4c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.mesos; import java.io.IOException; +import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,11 +55,11 @@ public MesosExternalShuffleClient( public void registerDriverWithShuffleService(String host, int port) throws IOException { checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); client.sendRpc(registerDriver, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.info("Successfully registered app " + appId + " with external shuffle service."); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632c..7fbe3384b4d4f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -53,7 +55,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -68,12 +70,12 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { } /** Serializes the 'type' byte followed by the message itself. */ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 1c2fa4d0d462c..19c870aebb023 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; @@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -107,8 +109,8 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test @@ -136,7 +138,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -144,7 +146,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -222,13 +224,13 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; // Create a second client, authenticated with a different app ID, and try to read from @@ -275,7 +277,7 @@ public synchronized void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a3..86c8609e7070b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e61390cf57061..9379412155e88 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,17 +77,18 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); @SuppressWarnings("unchecked") @@ -104,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -112,7 +113,7 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -120,7 +121,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd02..2590b9ce4c1f1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. final AtomicInteger expectedChunkIndex = new AtomicInteger(0); From c4d1ae0c832dc24c93366aeab14a7c619ecef5c2 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 30 Nov 2015 17:33:09 -0800 Subject: [PATCH 0981/1168] [SPARK-12049][CORE] User JVM shutdown hook can cause deadlock at shutdown Avoid potential deadlock with a user app's shutdown hook thread by more narrowly synchronizing access to 'hooks' Author: Sean Owen Closes #10042 from srowen/SPARK-12049. --- .../spark/util/ShutdownHookManager.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4012dca3ecdf8..620f226a23e15 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -206,7 +206,7 @@ private[spark] object ShutdownHookManager extends Logging { private [util] class SparkShutdownHookManager { private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false + @volatile private var shuttingDown = false /** * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not @@ -232,28 +232,27 @@ private [util] class SparkShutdownHookManager { } } - def runAll(): Unit = synchronized { + def runAll(): Unit = { shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) + var nextHook: SparkShutdownHook = null + while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) { + Try(Utils.logUncaughtExceptions(nextHook.run())) } } - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) + def add(priority: Int, hook: () => Unit): AnyRef = { + hooks.synchronized { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } } - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } + def remove(ref: AnyRef): Boolean = { + hooks.synchronized { hooks.remove(ref) } } } From 39b33597cf5450579142c4b032014100951b2258 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 17:15:47 -0800 Subject: [PATCH 0982/1168] [HOTFIX][SPARK-12000] Add missing quotes in Jekyll API docs plugin. I accidentally omitted these as part of #10049. --- docs/_plugins/copy_api_dirs.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index f2f3e2e653149..174c202e37918 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - system(make html) || raise("Python doc generation failed") + system("make html") || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") From c4618776c291fc76f26c3454627bba9ba6a2fb5b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 30 Nov 2015 20:56:42 -0800 Subject: [PATCH 0983/1168] [SPARK-12018][SQL] Refactor common subexpression elimination code JIRA: https://issues.apache.org/jira/browse/SPARK-12018 The code of common subexpression elimination can be factored and simplified. Some unnecessary variables can be removed. Author: Liang-Chi Hsieh Closes #10009 from viirya/refactor-subexpr-eliminate. --- .../sql/catalyst/expressions/Expression.scala | 10 ++---- .../expressions/codegen/CodeGenerator.scala | 34 +++++-------------- .../codegen/GenerateUnsafeProjection.scala | 4 +-- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 169435a10ea2c..b55d3653a7158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -94,13 +94,9 @@ abstract class Expression extends TreeNode[Expression] { def gen(ctx: CodeGenContext): GeneratedExpressionCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added - // as a function, `subExprState.fnName`. Just call that. - val code = - s""" - |/* $this */ - |${subExprState.fnName}(${ctx.INPUT_ROW}); - """.stripMargin.trim - GeneratedExpressionCode(code, subExprState.code.isNull, subExprState.code.value) + // as a function and called in advance. Just use it. + val code = s"/* $this */" + GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f3d6aeb86c5b..440c7d2fc1156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -104,16 +104,13 @@ class CodeGenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // State used for subexpression elimination. - case class SubExprEliminationState( - isLoaded: String, - code: GeneratedExpressionCode, - fnName: String) + case class SubExprEliminationState(isNull: String, value: String) // Foreach expression that is participating in subexpression elimination, the state to use. val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] - // The collection of isLoaded variables that need to be reset on each row. - val subExprIsLoadedVariables = mutable.ArrayBuffer.empty[String] + // The collection of sub-exression result resetting methods that need to be called on each row. + val subExprResetVariables = mutable.ArrayBuffer.empty[String] final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" @@ -408,7 +405,6 @@ class CodeGenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isLoaded = freshName("isLoaded") val isNull = freshName("isNull") val value = freshName("value") val fnName = freshName("evalExpr") @@ -417,18 +413,12 @@ class CodeGenContext { val code = expr.gen(this) val fn = s""" - |private void $fnName(InternalRow ${INPUT_ROW}) { - | if (!$isLoaded) { - | ${code.code.trim} - | $isLoaded = true; - | $isNull = ${code.isNull}; - | $value = ${code.value}; - | } + |private void $fnName(InternalRow $INPUT_ROW) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; |} """.stripMargin - code.code = fn - code.isNull = isNull - code.value = value addNewFunction(fnName, fn) @@ -448,18 +438,12 @@ class CodeGenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - - // Maintain the loaded value and isNull as member variables. This is necessary if the codegen - // function is split across multiple functions. - // TODO: maintaining this as a local variable probably allows the compiler to do better - // optimizations. - addMutableState("boolean", isLoaded, s"$isLoaded = false;") addMutableState("boolean", isNull, s"$isNull = false;") addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprIsLoadedVariables += isLoaded - val state = SubExprEliminationState(isLoaded, code, fnName) + subExprResetVariables += s"$fnName($INPUT_ROW);" + val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7b6c9373ebe30..68005afb21d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,8 +287,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the isLoaded flag for each row. - val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + // Reset the subexpression values for each row. + val subexprReset = ctx.subExprResetVariables.mkString("\n") val code = s""" From f693daed4d2d90fefc394e9bd77b826bf2b5aeb3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 1 Dec 2015 09:26:58 +0000 Subject: [PATCH 0984/1168] [SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11898 syn0Global and sync1Global in word2vec are quite large objects with size (vocab * vectorSize * 8), yet they are passed to worker using basic task serialization. Use broadcast can greatly improve the performance. My benchmark shows that, for 1M vocabulary and default vectorSize 100, changing to broadcast can help, 1. decrease the worker memory consumption by 45%. 2. decrease running time by 40%. This will also help extend the upper limit for Word2Vec. Author: Yuhao Yang Closes #9878 from hhbyyh/w2vBC. --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a47f27b0afb14..655ac0bb5545b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging { Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount @@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist() From e686df1c522c2a6d6c513d007533103370f27190 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 07:42:37 -0800 Subject: [PATCH 0985/1168] [SPARK-11949][SQL] Set field nullable property for GroupingSets to get correct results for null values JIRA: https://issues.apache.org/jira/browse/SPARK-11949 The result of cube plan uses incorrect schema. The schema of cube result should set nullable property to true because the grouping expressions will have null values. Author: Liang-Chi Hsieh Closes #10038 from viirya/fix-cube. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 94ffbbb2e5c65..b8f212fca7509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -223,6 +223,11 @@ class Analyzer( case other => Alias(other, other.toString)() } + // TODO: We need to use bitmasks to determine which grouping expressions need to be + // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need + // to change the nullability of a. + val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change // it so that the aggregation is computed on the unmodified value of its argument @@ -231,12 +236,13 @@ class Analyzer( // If not then its a grouping expression and we need to use the modified (with nulls from // Expand) value of the expression. case expr => expr.transformDown { - case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + case e => + groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(_.toAttribute) + val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b5c636d0de1d1..b1004bc5bc290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -86,6 +87,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, 2013, 78000.0) :: Row(null, null, 113000.0) :: Nil ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) } test("rollup overlapping columns") { From 948c6f1e006ef1d8ba0fd06e3a2fde06f02f0c5d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 09:45:55 -0800 Subject: [PATCH 0986/1168] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize `JavaSerializerInstance.serialize` uses `ByteArrayOutputStream.toByteArray` to get the serialized data. `ByteArrayOutputStream.toByteArray` needs to copy the content in the internal array to a new array. However, since the array will be converted to `ByteBuffer` at once, we can avoid the memory copy. This PR added `ByteBufferOutputStream` to access the protected `buf` and convert it to a `ByteBuffer` directly. Author: Shixiong Zhu Closes #10051 from zsxwing/SPARK-12060. --- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7d..ea718a0edbe71 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 0000000000000..92e45224db81c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From 2e9a9431f7a007235307324130f2c479c0d3b64d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Dec 2015 10:21:31 -0800 Subject: [PATCH 0987/1168] [SPARK-12046][DOC] Fixes various ScalaDoc/JavaDoc issues This PR backports PR #10039 to master Author: Cheng Lian Closes #10063 from liancheng/spark-12046.doc-fix.master. --- .../spark/api/java/function/Function4.java | 2 +- .../spark/api/java/function/VoidFunction.java | 2 +- .../api/java/function/VoidFunction2.java | 2 +- .../apache/spark/api/java/JavaPairRDD.scala | 16 +++---- .../org/apache/spark/memory/package.scala | 14 +++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +-- .../org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 5 ++- .../serializer/SerializationDebugger.scala | 13 +++--- .../scala/org/apache/spark/util/Vector.scala | 1 + .../util/collection/ExternalSorter.scala | 30 ++++++------- .../WritablePartitionedPairCollection.scala | 7 +-- .../streaming/kinesis/KinesisReceiver.scala | 23 +++++----- .../streaming/kinesis/KinesisUtils.scala | 13 +++--- .../mllib/optimization/GradientDescent.scala | 12 ++--- project/SparkBuild.scala | 2 + .../scala/org/apache/spark/sql/Column.scala | 11 ++--- .../spark/streaming/StreamingContext.scala | 11 ++--- .../streaming/dstream/FileInputDStream.scala | 19 ++++---- .../streaming/receiver/BlockGenerator.scala | 22 ++++----- .../scheduler/ReceiverSchedulingPolicy.scala | 45 ++++++++++--------- .../util/FileBasedWriteAheadLog.scala | 7 +-- .../spark/streaming/util/RecurringTimer.scala | 8 ++-- .../org/apache/spark/deploy/yarn/Client.scala | 10 ++--- 25 files changed, 152 insertions(+), 133 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java index fd727d64863d7..9c35a22ca9d0f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function4.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -23,5 +23,5 @@ * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. */ public interface Function4 extends Serializable { - public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; + R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2a10435b7523a..f30d42ee57966 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -23,5 +23,5 @@ * A function with no return value. */ public interface VoidFunction extends Serializable { - public void call(T t) throws Exception; + void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java index 6c576ab678455..da9ae1c9c5cdc 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -23,5 +23,5 @@ * A two-argument function that takes arguments of type T1 and T2 with no return value. */ public interface VoidFunction2 extends Serializable { - public void call(T1 v1, T2 v2) throws Exception; + void call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0b0c6e5bb8cc1..87deaf20e2b25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -215,13 +215,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple @@ -247,13 +247,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala index 564e30d2ffd66..3d00cd9cb6377 100644 --- a/core/src/main/scala/org/apache/spark/memory/package.scala +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -21,13 +21,13 @@ package org.apache.spark * This package implements Spark's memory management system. This system consists of two main * components, a JVM-wide memory manager and a per-task manager: * - * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. - * This component implements the policies for dividing the available memory across tasks and for - * allocating memory between storage (memory used caching and data transfer) and execution (memory - * used by computations, such as shuffles, joins, sorts, and aggregations). - * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual tasks. - * Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide - * MemoryManager. + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution + * (memory used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual + * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. * * Internally, each of these components have additional abstractions for memory bookkeeping: * diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 935c3babd8ea1..3a0ca1d813297 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -70,7 +70,7 @@ private[spark] class CoGroupPartition( * * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of * instantiating this directly. - + * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c6181902ace6d..44d195587a081 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -65,9 +65,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a013c3f66a3a8..3ef506e1562bf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -86,7 +86,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def getPreferredLocations(partition: Partition): Seq[String] = { + override protected def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] tracker.getPreferredLocationsForShuffle(dep, partition.index) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4fb32ba8cb188..2fcd5aa57d11b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -33,8 +33,9 @@ import org.apache.spark.util.Utils /** * A unit of execution. We have two kinds of Task's in Spark: - * - [[org.apache.spark.scheduler.ShuffleMapTask]] - * - [[org.apache.spark.scheduler.ResultTask]] + * + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] * * A Spark job consists of one or more stages. The very last stage in a job consists of multiple * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index a1b1e1631eafb..e2951d8a3e096 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging { /** * Find the path leading to a not serializable object. This method is modeled after OpenJDK's * serialization mechanism, and handles the following cases: - * - primitives - * - arrays of primitives - * - arrays of non-primitive objects - * - Serializable objects - * - Externalizable objects - * - writeReplace + * + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index 2ed827eab46df..6b3fa8491904a 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -122,6 +122,7 @@ class Vector(val elements: Array[Double]) extends Serializable { override def toString: String = elements.mkString("(", ", ", ")") } +@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") object Vector { def apply(elements: Array[Double]): Vector = new Vector(elements) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2440139ac95e9..44b1d90667e65 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -67,24 +67,24 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedPairBuffer if we don't. - * Inside these buffers, we sort elements by partition ID and then possibly also by key. - * To avoid calling the partitioner multiple times with each key, we store the partition ID - * alongside each record. + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * - * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first - * by partition ID and possibly second by key or by hash code of the key, if we want to do - * aggregation. For each file, we track how many objects were in each partition in memory, so we - * don't have to write out the partition ID for every element. + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. * - * - When the user requests an iterator or file output, the spilled files are merged, along with - * any remaining in-memory data, using the same sort order defined above (unless both sorting - * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering - * from the ordering parameter, or read the keys with the same hash code and compare them with - * each other for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * - * - Users are expected to call stop() at the end to delete all the intermediate files. + * - Users are expected to call stop() at the end to delete all the intermediate files. */ private[spark] class ExternalSorter[K, V, C]( context: TaskContext, diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 38848e9018c6c..5232c2bd8d6f6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that - * - Have an associated partition for each key-value pair. - * - Support a memory-efficient sorted iterator - * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. */ private[spark] trait WritablePartitionedPairCollection[K, V] { /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 97dbb918573a3..05080835fc4ad 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -46,17 +46,18 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * * The way this Receiver works is as follows: - * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple - * KinesisRecordProcessor - * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is - * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. - * - When the block generator defines a block, then the recorded sequence number ranges that were - * inserted into the block are recorded separately for being used later. - * - When the block is ready to be pushed, the block is pushed and the ranges are reported as - * metadata of the block. In addition, the ranges are used to find out the latest sequence - * number for each shard that can be checkpointed through the DynamoDB. - * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence - * number for it own shard. + * + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2849fd8a82102..2de6195716e5c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -226,12 +226,13 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. + * + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name + * in [[org.apache.spark.SparkConf]]. * * @param ssc StreamingContext object * @param streamName Kinesis stream name diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 3b663b5defb03..37bb6f6097f67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -81,11 +81,13 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the convergence tolerance. Default 0.001 * convergenceTol is a condition which decides iteration termination. * The end of iteration is decided based on below logic. - * - If the norm of the new solution vector is >1, the diff of solution vectors - * is compared to relative tolerance which means normalizing by the norm of - * the new solution vector. - * - If the norm of the new solution vector is <=1, the diff of solution vectors - * is compared to absolute tolerance which is not normalizing. + * + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * * Must be between 0.0 and 1.0 inclusively. */ def setConvergenceTol(tolerance: Double): this.type = { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 63290d8a666e6..b1dcaedcba75e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -535,6 +535,8 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3cd9e1eff142..ad6af481fadc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -136,11 +136,12 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** * Extracts a value or values from a complex type. * The following types of extraction are supported: - * - Given an Array, an integer ordinal can be used to retrieve a single value. - * - Given a Map, a key of the correct type can be used to retrieve an individual value. - * - Given a Struct, a string fieldName can be used to extract that field. - * - Given an Array of Structs, a string fieldName can be used to extract filed - * of every struct in that array, and return an Array of fields + * + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields * * @group expr_ops * @since 1.4.0 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index aee172a4f549a..6fb8ad38abcec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -574,11 +574,12 @@ class StreamingContext private[streaming] ( * :: DeveloperApi :: * * Return the current state of the context. The context can be in three possible states - - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. - * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. - * Input DStreams, transformations and output operations cannot be created on the context. - * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + * + * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @DeveloperApi def getState(): StreamingContextState = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 40208a64861fb..cb5b1f252e90c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * class remembers the information about the files selected in past batches for * a certain duration (say, "remember window") as shown in the figure below. * + * {{{ * |<----- remember window ----->| * ignore threshold --->| |<--- current batch time * |____.____.____.____.____.____| @@ -49,6 +50,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * ---------------------|----|----|----|----|----|----|-----------------------> Time * |____|____|____|____|____|____| * remembered batches + * }}} * * The trailing end of the window is the "ignore threshold" and all files whose mod times * are less than this threshold are assumed to have already been selected and are therefore @@ -59,14 +61,15 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * `isNewFile` for more details. * * This makes some assumptions from the underlying file system that the system is monitoring. - * - The clock of the file system is assumed to synchronized with the clock of the machine running - * the streaming app. - * - If a file is to be visible in the directory listings, it must be visible within a certain - * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be - * selected as the mod time will be less than the ignore threshold when it becomes visible. - * - Once a file is visible, the mod time cannot change. If it does due to appends, then the - * processing semantics are undefined. + * + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 421d60ae359f8..cc7c04bfc9f63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -84,13 +84,14 @@ private[streaming] class BlockGenerator( /** * The BlockGenerator can be in 5 possible states, in the order as follows. - * - Initialized: Nothing has been started - * - Active: start() has been called, and it is generating blocks on added data. - * - StoppedAddingData: stop() has been called, the adding of data has been stopped, - * but blocks are still being generated and pushed. - * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but - * they are still being pushed. - * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + * + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. */ private object GeneratorState extends Enumeration { type GeneratorState = Value @@ -125,9 +126,10 @@ private[streaming] class BlockGenerator( /** * Stop everything in the right order such that all the data added is pushed out correctly. - * - First, stop adding data to the current buffer. - * - Second, stop generating blocks. - * - Finally, wait for queue of to-be-pushed blocks to be drained. + * + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. */ def stop(): Unit = { // Set the state to stop adding data diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 234bc8660da8a..391a461f08125 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -27,28 +27,29 @@ import org.apache.spark.streaming.receiver.Receiver * A class that tries to schedule receivers with evenly distributed. There are two phases for * scheduling receivers. * - * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule - * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. - * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker should - * update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list - * that contains the scheduled locations. Then when a receiver is starting, it will send a - * register request and `ReceiverTracker.registerReceiver` will be called. In - * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should check - * if the location of this receiver is one of the scheduled locations, if not, the register will - * be rejected. - * - The second phase is local scheduling when a receiver is restarting. There are two cases of - * receiver restarting: - * - If a receiver is restarting because it's rejected due to the real location and the scheduled - * locations mismatching, in other words, it fails to start in one of the locations that - * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are - * still alive in the list of scheduled locations, then use them to launch the receiver job. - * - If a receiver is restarting without a scheduled locations list, or the executors in the list - * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should - * not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it should clear - * it. Then when this receiver is registering, we can know this is a local scheduling, and - * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching - * location is matching. + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker + * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should + * check if the location of this receiver is one of the scheduled locations, if not, the register + * will be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * locations mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that + * are still alive in the list of scheduled locations, then use them to launch the receiver + * job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` + * should not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it + * should clear it. Then when this receiver is registering, we can know this is a local + * scheduling, and `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if + * the launching location is matching. * * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, * otherwise do local scheduling. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index f5165f7c39122..a99b570835831 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -34,9 +34,10 @@ import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. + * + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. * * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 0148cb51c6f09..bfb53614050a7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,10 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * - interruptTimer = true will interrupt the callback - * if it is in progress (not guaranteed to give correct time in this case). - * - interruptTimer = false guarantees that there will be at least one callback after `stop` has - * been called. + * + * @param interruptTimer True will interrupt the callback if it is in progress (not guaranteed to + * give correct time in this case). False guarantees that there will be at + * least one callback after `stop` has been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a77a3e2420e24..f0590d2d222ec 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1336,11 +1336,11 @@ object Client extends Logging { * * This method uses two configuration values: * - * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may - * only be valid in the gateway node. - * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may - * contain, for example, env variable references, which will be expanded by the NMs when - * starting containers. + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. * * If either config is not available, the input path is returned. */ From 163e38eb2a1343a9883e7b7b7f58db835558a819 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:22:55 -0800 Subject: [PATCH 0988/1168] [SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail The reason is that, for a single culumn `RowEncoder`(or a single field product encoder), when we use it as the encoder for grouping key, we should also combine the grouping attributes, although there is only one grouping attribute. Author: Wenchen Fan Closes #10059 from cloud-fan/bug. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 7 ++++--- .../org/apache/spark/sql/DatasetSuite.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 6 +++--- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index da4600133290f..c357f88a94dd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -70,7 +70,7 @@ class Dataset[T] private[sql]( * implicit so that we can use it when constructing new [[Dataset]] objects that have the same * object type (that will be possibly resolved to a different schema). */ - private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a10a89342fb5c..4bf0b256fcb4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (groupingAttributes.length > 1) { - Alias(CreateStruct(groupingAttributes), "key")() - } else { + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7d539180ded9e..a2c8d201563e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3 -> "abcxyz", 5 -> "hello") } + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") @@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4ccfd89..8f476dd0f99b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. */ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkAnswer[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } From 9cf3e2e423e075634f8d55690fb652cb871a82a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:24:53 -0800 Subject: [PATCH 0989/1168] [SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff and lost the required data type, which may lead to runtime error if the real type doesn't match the encoder's schema. For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type is `[a: int, b: long]`, then we will hit runtime error and say that we can't construct class `Data` with int and long, because we lost the information that `b` should be a string. Author: Wenchen Fan Closes #9840 from cloud-fan/err-msg. --- .../spark/sql/catalyst/ScalaReflection.scala | 93 +++++++-- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++ .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 9 + .../expressions/complexTypeCreator.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 12 ++ .../encoders/EncoderResolutionSuite.scala | 180 ++++++++++++++++++ .../spark/sql/DatasetAggregatorSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 21 +- 10 files changed, 335 insertions(+), 32 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d133ad3f0d89d..9b6b5b8bd1a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) + def constructorFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + constructorFor(tpe, None, walkedTypePath) + } private def constructorFor( tpe: `Type`, - path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal( + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } + } + + /** + * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * don't need to cast struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it, thus we only need to handle leaf type. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - WrapOption(constructorFor(optType, path)) + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(constructorFor(optType, path, newTypePath)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection { primitiveMethod.map { method => Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val arrayData = Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection { arrayData :: Nil) case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => constructorFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => constructorFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val dataType = schemaFor(fieldType).dataType - + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + constructorFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) } else { - constructorFor(fieldType, Some(addToPath(fieldName))) + constructorFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b8f212fca7509..765327c474e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolvePivot :: + ResolveUpCast :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. + */ +object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index f90fc3cc12189..29502a59915f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -53,7 +53,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0c10a56c555f4..06ffe864552fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{StructField, ObjectType, StructType} @@ -235,12 +236,13 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), // we need an instance of the outer scope. This rule substitues those outer objects into // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` // registry. - copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => val outer = outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a2c6c39fd8ce2..cb60d5958d535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -914,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } } + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. + */ +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable { + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1854dfaa7db35..72cc89c8be915 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** - * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this * StructType. */ def flatten: Seq[NamedExpression] = valExprs.zip(names).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa928..ce45245b9f6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { case _ => false } + /** + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. + */ + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + /** * The default size of a value of the DecimalType is 4096 bytes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala new file mode 100644 index 0000000000000..0289988342e78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolutionSuite extends PlanTest { + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + val cls = classOf[StringLongClass] + + { + val attrs = Seq('a.string, 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + { + val attrs = Seq('a.int, 'b.long) + val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + val expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val innerCls = classOf[StringLongClass] + val cls = classOf[ComplexClass] + + val structType = new StructType().add("a", IntegerType).add("b", LongType) + val attrs = Seq('a.int, 'b.struct(structType)) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + Seq( + 'a.int.cast(LongType), + If( + 'b.struct(structType).isNull, + Literal.create(null, ObjectType(innerCls)), + NewInstance( + innerCls, + Seq( + toExternalString( + GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), + GetStructField('b.struct(structType), 1, Some("b"))), + false, + ObjectType(innerCls)) + )), + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val cls = classOf[StringLongClass] + + val structType = new StructType().add("a", StringType).add("b", ByteType) + val attrs = Seq('a.struct(structType), 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + classOf[Tuple2[_, _]], + Seq( + NewInstance( + cls, + Seq( + toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), + GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + false, + ObjectType(cls)), + 'b.int.cast(LongType)), + false, + ObjectType(classOf[Tuple2[_, _]])) + compareExpressions(fromRowExpr, expected) + } + + private def toExternalString(e: Expression): Expression = { + Invoke(e, "toString", ObjectType(classOf[String]), Nil) + } + + test("throw exception if real type is not compatible with encoder schema") { + val msg1 = intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + }.message + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2 == + s""" + |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + castFail[String, Long] + + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 19dce5d1e2f37..c6d2bf07b2803 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( sum(_._2), - expr("sum(_2)").as[Int], + expr("sum(_2)").as[Long], count("*")), - ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) } test("typed aggregation: complex case") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a2c8d201563e5..542e4d6c43b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int]), - ("a", 30), ("b", 3), ("c", 1)) + ds.groupBy(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) } test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), - ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), - ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } test("typed aggregation: expr, expr, expr, expr") { @@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( - sum("_2").as[Int], + sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*").as[Long], avg("_2").as[Double]), - ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } test("cogroup") { @@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((nullInt, "1"), (new java.lang.Integer(22), "2")), ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } } From 01ab37dd5eb3532fb5d25770ffcad1eb876b2f4c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:35:12 -0800 Subject: [PATCH 0990/1168] [SPARK-11954][SQL] Encoder for JavaBeans create java version of `constructorFor` and `extractorFor` in `JavaTypeInference` Author: Wenchen Fan This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #9937 from cloud-fan/pojo. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 + .../sql/catalyst/JavaTypeInference.scala | 313 +++++++++++++++++- .../catalyst/encoders/ExpressionEncoder.scala | 21 +- .../sql/catalyst/expressions/objects.scala | 42 ++- .../spark/sql/catalyst/trees/TreeNode.scala | 27 +- .../sql/catalyst/util/ArrayBasedMapData.scala | 5 + .../sql/catalyst/util/GenericArrayData.scala | 3 + .../sql/catalyst/trees/TreeNodeSuite.scala | 25 ++ .../apache/spark/sql/JavaDatasetSuite.java | 174 +++++++++- 9 files changed, 608 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 03aa25eda807f..c40061ae0aafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7d4cfbe6faecb..c8ee87e8819f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{PropertyDescriptor, Introspector} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, Map => JMap, List => JList} import scala.language.existentials import com.google.common.reflect.TypeToken + import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.unsafe.types.UTF8String + /** * Type-inference utilities for POJOs and Java collections. @@ -33,13 +39,14 @@ object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. * @param beanClass Java type * @return (SQL data type, nullable) */ @@ -58,6 +65,8 @@ object JavaTypeInference { (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -87,15 +96,14 @@ object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") val fields = properties.map { property => @@ -107,11 +115,294 @@ object JavaTypeInference { } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of java bean `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor(beanClass: Class[_]): Expression = { + constructorFor(TypeToken.of(beanClass), None) + } + + private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => constructorFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + }.toMap + + val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + /** + * Returns expressions for extracting all the fields from the given type. + */ + def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have same iteration order(which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. + throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + } + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 06ffe864552fd..3e8420ecb9ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** @@ -68,6 +67,22 @@ object ExpressionEncoder { ClassTag[T](cls)) } + // TODO: improve error message for java bean encoder. + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val toRowExpression = JavaTypeInference.extractorsFor(beanClass) + val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](beanClass)) + } + /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an * N-tuple. Note that these encoders should be unresolved so that information about @@ -216,7 +231,7 @@ case class ExpressionEncoder[T]( */ def assertUnresolved(): Unit = { (fromRowExpression +: toRowExpressions).foreach(_.foreach { - case a: AttributeReference => + case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 62d09f0f55105..e6ab9a31be59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. This is does as a lambda function so that @@ -386,6 +387,8 @@ case class MapObjects( (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) case ArrayType(t, _) => val (sqlType, primitiveElement) = t match { case m: MapType => (m, false) @@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override def dataType: DataType = ObjectType(tag.runtimeClass) } + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val instanceGen = beanInstance.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instanceGen.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instanceGen.isNull + ev.value = instanceGen.value + + s""" + ${instanceGen.code} + if (!${instanceGen.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 35f087baccdee..f1cea07976a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.trees +import scala.collection.Map + import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.types.{StructType, DataType} @@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 70b028d2b3f7c..d85b72ed83def 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -70,4 +70,9 @@ object ArrayBasedMapData { def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 96588bb5dc1bc..2b8cdc1e23ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b342..965bdb1515e55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 67a3190cb7d4f..ae47f4fe0e231 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -31,14 +31,15 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { private transient JavaSparkContext jsc; @@ -506,4 +507,169 @@ public void testJavaEncoderErrorMessageForPrivateClass() { public void testKryoEncoderErrorMessageForPrivateClass() { Encoders.kryo(PrivateClassTest.class); } + + public class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + private List f; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + if (!e.equals(that.e)) return false; + return f.equals(that.f); + } + + @Override + public int hashCode() { + int result = (a ? 1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); + return result; + } + } + + public class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1, 2}); + obj1.setD(new String[]{"hello", null}); + obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{null, "world"}); + obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); + } } From efd966102eb31af1ea96d442497e944f3254509b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 1 Dec 2015 10:38:59 -0800 Subject: [PATCH 0991/1168] [SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs Persist and Unpersist exist in both RDD and Dataframe APIs. I think they are still very critical in Dataset APIs. Not sure if my understanding is correct? If so, could you help me check if the implementation is acceptable? Please provide your opinions. marmbrus rxin cloud-fan Thank you very much! Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #9889 from gatorsmile/persistDS. --- .../org/apache/spark/sql/DataFrame.scala | 9 +++ .../scala/org/apache/spark/sql/Dataset.scala | 50 +++++++++++- .../org/apache/spark/sql/SQLContext.scala | 9 +++ .../spark/sql/execution/CacheManager.scala | 27 ++++--- .../apache/spark/sql/DatasetCacheSuite.scala | 80 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 5 +- 6 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6197f10813a3b..eb8700369275e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1584,6 +1584,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1593,12 +1594,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1608,6 +1614,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @group basic * @since 1.3.0 */ @@ -1617,6 +1625,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c357f88a94dd0..d6bb1d2ad8e50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** @@ -565,7 +566,7 @@ class Dataset[T] private[sql]( * combined. * * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) @@ -618,7 +619,6 @@ class Dataset[T] private[sql]( case _ => Alias(CreateStruct(rightOutput), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => @@ -697,11 +697,55 @@ class Dataset[T] private[sql]( */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @since 1.6.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + /* ******************** * * Internal Functions * * ******************** */ - private[sql] def logicalPlan = queryExecution.analyzed + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9cc65de19180a..4e26250868374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -338,6 +338,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 293fcfe96e677..50f6562815c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Queryable]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,9 +108,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[Queryable]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Queryable]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 0000000000000..3a283a4e1f610 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8f476dd0f99b6..bc22fb8b7bdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable abstract class QueryTest extends PlanTest { @@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached From 6f8fb6f712ea0a9170d0e7f9f3c4eb5b90d3faff Mon Sep 17 00:00:00 2001 From: woj-i Date: Tue, 1 Dec 2015 11:05:45 -0800 Subject: [PATCH 0992/1168] [SPARK-11821] Propagate Kerberos keytab for all environments andrewor14 the same PR as in branch 1.5 harishreedharan Author: woj-i Closes #9859 from woj-i/master. --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 4 ++++ docs/running-on-yarn.md | 4 ++-- docs/sql-programming-guide.md | 7 ++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 2e912b59afdb8..52d3ab34c1784 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -545,6 +545,10 @@ object SparkSubmit { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } + } + + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL) { if (args.principal != null) { require(args.keytab != null, "Keytab must be specified when principal is specified") if (!new File(args.keytab).exists()) { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 925a1e0ba6fcf..06413f83c3a71 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -358,14 +358,14 @@ If you need a reference to the proper location to put log files in the YARN so t The full path to the file that contains the keytab for the principal specified above. This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. + for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) spark.yarn.principal (none) - Principal to be used to login to KDC, while running on secure HDFS. + Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d7b205c2fa0df..7b1d97baa3823 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1614,7 +1614,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), + `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the @@ -2028,7 +2029,7 @@ Beeline will ask you for a username and password. In non-secure mode, simply ent your machine and a blank password. For secure mode, please follow the instructions given in the [beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may also use the beeline script that comes with Hive. @@ -2053,7 +2054,7 @@ To start the Spark SQL CLI, run the following in the Spark directory: ./bin/spark-sql -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. From a46c68b6ebcd257cebe42c4e255a8a1a9f831787 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Dec 2015 11:49:20 -0800 Subject: [PATCH 0993/1168] [SPARK-12065] Upgrade Tachyon from 0.8.1 to 0.8.2 This commit upgrades the Tachyon dependency from 0.8.1 to 0.8.2. Author: Josh Rosen Closes #10054 from JoshRosen/upgrade-to-tachyon-0.8.2. --- core/pom.xml | 2 +- make-distribution.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 37e3f168ab374..61744bb5c7bf5 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -270,7 +270,7 @@ org.tachyonproject tachyon-client - 0.8.1 + 0.8.2 org.apache.hadoop diff --git a/make-distribution.sh b/make-distribution.sh index 7b417fe7cf619..e64ceb802464c 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.8.1" +TACHYON_VERSION="0.8.2" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" From 7f4be9fae0d61c9103e482736ac364d6ae8831d0 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 12:59:53 -0800 Subject: [PATCH 0994/1168] [SPARK-12030] Fix Platform.copyMemory to handle overlapping regions. This bug was exposed as memory corruption in Timsort which uses copyMemory to copy large regions that can overlap. The prior implementation did not handle this case half the time and always copied forward, resulting in the data being corrupt. Author: Nong Li Closes #10068 from nongli/spark-12030. --- .../org/apache/spark/unsafe/Platform.java | 27 ++++++-- .../spark/unsafe/PlatformUtilSuite.java | 61 +++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1c16da982923b..0d6b215fe5aaf 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -107,12 +107,27 @@ public static void freeMemory(long address) { public static void copyMemory( Object src, long srcOffset, Object dst, long dstOffset, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java new file mode 100644 index 0000000000000..693ec6ec58dbd --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +public class PlatformUtilSuite { + + @Test + public void overlappingCopyMemory() { + byte[] data = new byte[3 * 1024 * 1024]; + int size = 2 * 1024 * 1024; + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size); + for (int i = 0; i < data.length; ++i) { + Assert.assertEquals((byte)i, data[i]); + } + + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET + 1, + data, + Platform.BYTE_ARRAY_OFFSET, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)(i + 1), data[i]); + } + + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET, + data, + Platform.BYTE_ARRAY_OFFSET + 1, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)i, data[i + 1]); + } + } +} From 3ff48ba6cdeb75d071e7b66715af2d076365afb2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 14:08:36 -0800 Subject: [PATCH 0995/1168] [SPARK-12004] Preserve the RDD partitioner through RDD checkpointing The solution is the save the RDD partitioner in a separate file in the RDD checkpoint directory. That is, `/_partitioner`. In most cases, whether the RDD partitioner was recovered or not, does not affect the correctness, only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner. If either fails, the checkpointing is not affected. This makes this patch safe and backward compatible. Author: Tathagata Das Closes #9983 from tdas/SPARK-12004. --- .../spark/rdd/ReliableCheckpointRDD.scala | 122 +++++++++++++++++- .../spark/rdd/ReliableRDDCheckpointData.scala | 21 +-- .../org/apache/spark/CheckpointSuite.scala | 61 ++++++++- 3 files changed, 173 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index a69be6a068bbf..fa71b8c26233d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -20,12 +20,12 @@ package org.apache.spark.rdd import java.io.IOException import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -33,8 +33,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( sc: SparkContext, - val checkpointPath: String) - extends CheckpointRDD[T](sc) { + val checkpointPath: String, + _partitioner: Option[Partitioner] = None + ) extends CheckpointRDD[T](sc) { @transient private val hadoopConf = sc.hadoopConfiguration @transient private val cpath = new Path(checkpointPath) @@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( /** * Return the path of the checkpoint directory this RDD reads data from. */ - override def getCheckpointFile: Option[String] = Some(checkpointPath) + override val getCheckpointFile: Option[String] = Some(checkpointPath) + + override val partitioner: Option[Partitioner] = { + _partitioner.orElse { + ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath) + } + } /** * Return partitions described by the files in the checkpoint directory. @@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging { "part-%05d".format(partitionIndex) } + private def checkpointPartitionerFileName(): String = { + "_partitioner" + } + + /** + * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD. + */ + def writeRDDToCheckpointDirectory[T: ClassTag]( + originalRDD: RDD[T], + checkpointDir: String, + blockSize: Int = -1): ReliableCheckpointRDD[T] = { + + val sc = originalRDD.sparkContext + + // Create the output path for the checkpoint + val checkpointDirPath = new Path(checkpointDir) + val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration) + if (!fs.mkdirs(checkpointDirPath)) { + throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = sc.broadcast( + new SerializableConfiguration(sc.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + sc.runJob(originalRDD, + writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _) + + if (originalRDD.partitioner.nonEmpty) { + writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) + } + + val newRDD = new ReliableCheckpointRDD[T]( + sc, checkpointDirPath.toString, originalRDD.partitioner) + if (newRDD.partitions.length != originalRDD.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + } + newRDD + } + /** - * Write this partition's values to a checkpoint file. + * Write a RDD partition's data to a checkpoint file. */ - def writeCheckpointFile[T: ClassTag]( + def writePartitionToCheckpointFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { @@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging { } } + /** + * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort + * basis; any exception while writing the partitioner is caught, logged and ignored. + */ + private def writePartitionerToCheckpointDir( + sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = { + try { + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeObject(partitioner) + } { + serializeStream.close() + } + logDebug(s"Written partitioner to $partitionerFilePath") + } catch { + case NonFatal(e) => + logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath") + } + } + + + /** + * Read a partitioner from the given RDD checkpoint directory, if it exists. + * This is done on a best-effort basis; any exception while reading the partitioner is + * caught, logged and ignored. + */ + private def readCheckpointedPartitionerFile( + sc: SparkContext, + checkpointDirPath: String): Option[Partitioner] = { + try { + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(partitionerFilePath)) { + val fileInputStream = fs.open(partitionerFilePath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + val partitioner = Utils.tryWithSafeFinally[Partitioner] { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } + logDebug(s"Read partitioner from $partitionerFilePath") + Some(partitioner) + } else { + logDebug("No partitioner file") + None + } + } catch { + case NonFatal(e) => + logWarning(s"Error reading partitioner from $checkpointDirPath, " + + s"partitioner will not be recovered which may lead to performance loss", e) + None + } + } + /** * Read the content of the specified checkpoint file. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 91cad6662e4d2..cac6cbe780e91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v * This is called immediately after the first action invoked on this RDD has completed. */ protected override def doCheckpoint(): CheckpointRDD[T] = { - - // Create the output path for the checkpoint - val path = new Path(cpDir) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $cpDir") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) - val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $rdd(${rdd.partitions.length})") - } + val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) // Optionally clean our checkpoint files if the reference is out of scope if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { @@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v } logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") - newRDD } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index ab23326c6c25d..553d46285ac03 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,7 +21,8 @@ import java.io.File import scala.reflect.ClassTag -import org.apache.spark.CheckpointSuite._ +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite => // Test whether the checkpoint file has been created if (reliableCheckpoint) { - assert( - collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + assert(operatedRDD.getCheckpointFile.nonEmpty) + val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get) + assert(collectFunc(recoveredRDD) === result) + assert(recoveredRDD.partitioner === operatedRDD.partitioner) } // Test whether dependencies have been changed from its earlier parent RDD @@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite => } /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - protected def runTest(name: String)(body: Boolean => Unit): Unit = { + protected def runTest( + name: String, + skipLocalCheckpoint: Boolean = false + )(body: Boolean => Unit): Unit = { test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) + if (!skipLocalCheckpoint) { + test(name + " [local checkpoint]")(body(false)) + } } /** @@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(flatMappedRDD.collect() === result) } + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + + def testPartitionerCheckpointing( + partitioner: Partitioner, + corruptPartitionerFile: Boolean = false + ): Unit = { + val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner) + rddWithPartitioner.checkpoint() + rddWithPartitioner.count() + assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty, + "checkpointing was not successful") + + if (corruptPartitionerFile) { + // Overwrite the partitioner file with garbage data + val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get) + val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration) + val partitionerFile = fs.listStatus(checkpointDir) + .find(_.getPath.getName.contains("partitioner")) + .map(_.getPath) + require(partitionerFile.nonEmpty, "could not find the partitioner file for testing") + val output = fs.create(partitionerFile.get, true) + output.write(100) + output.close() + } + + val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get) + assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered") + + if (!corruptPartitionerFile) { + assert(newRDD.partitioner != None, "partitioner not recovered") + assert(newRDD.partitioner === rddWithPartitioner.partitioner, + "recovered partitioner does not match") + } else { + assert(newRDD.partitioner == None, "partitioner unexpectedly recovered") + } + } + + testPartitionerCheckpointing(partitioner) + + // Test that corrupted partitioner file does not prevent recovery of RDD + testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) + } + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) From e61fe19da8d86436caaff1295528a43384f190c6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 15:13:10 -0800 Subject: [PATCH 0996/1168] Revert "[SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize" This reverts commit 1401166576c7018c5f9c31e0a6703d5fb16ea339. --- .../spark/serializer/JavaSerializer.scala | 7 +++-- .../spark/util/ByteBufferOutputStream.scala | 31 ------------------- 2 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index ea718a0edbe71..b463a71d5bd7d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,7 +24,8 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -95,11 +96,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteBufferOutputStream() + val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - bos.toByteBuffer + ByteBuffer.wrap(bos.toByteArray) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala deleted file mode 100644 index 92e45224db81c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer - -/** - * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer - */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { - - def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) - } -} From cd6f3303bb591ed126740685d57f242ffaef4284 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Dec 2015 15:21:53 -0800 Subject: [PATCH 0997/1168] [SPARK-11961][DOC] Add docs of ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11961 Author: Xusen Yin Closes #9965 from yinxusen/SPARK-11961. --- docs/ml-features.md | 50 +++++++++++++ .../examples/ml/JavaChiSqSelectorExample.java | 71 +++++++++++++++++++ .../examples/ml/ChiSqSelectorExample.scala | 57 +++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index cd1838d6d2882..5f888775553f2 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1949,3 +1949,53 @@ output.select("features", "label").show() {% endhighlight %} + +## ChiSqSelector + +`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with +categorical features. ChiSqSelector orders features based on a +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) +from the class, and then filters (selects) the top features which the class label depends on the +most. This is akin to yielding the features with the most predictive power. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `clicked`, which is used as +our target to be predicted: + +~~~ +id | features | clicked +---|-----------------------|--------- + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 +~~~ + +If we use `ChiSqSelector` with a `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` chosen as the most useful feature: + +~~~ +id | features | clicked | selectedFeatures +---|-----------------------|---------|------------------ + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 | [1.0] + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 | [0.0] + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 | [0.1] +~~~ + +
    +
    + +Refer to the [ChiSqSelector Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ChiSqSelector) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala %} +
    + +
    + +Refer to the [ChiSqSelector Java docs](api/java/org/apache/spark/ml/feature/ChiSqSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %} +
    +
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java new file mode 100644 index 0000000000000..ede05d6e20118 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.ChiSqSelector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + ChiSqSelector selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures"); + + DataFrame result = selector.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala new file mode 100644 index 0000000000000..a8d2bc4907e80 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ChiSqSelectorExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Seq( + (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + ) + + val df = sc.parallelize(data).toDF("id", "features", "clicked") + + val selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 09fd7672a18d58c42389b924abfb644fc6cbbdec Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 1 Dec 2015 15:26:10 -0800 Subject: [PATCH 0998/1168] [SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue Fixed a minor race condition in #10017 Closes #10017 Author: jerryshao Author: Shixiong Zhu Closes #10074 from zsxwing/review-pr10017. --- python/pyspark/streaming/tests.py | 49 +++++++++++++++++++++++++++++++ python/pyspark/streaming/util.py | 13 ++++---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a647e6bf39581..d50c6b8d4a428 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1149,6 +1149,55 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_with_checkpoint(self): + """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + self.ssc.stop(False) + self.ssc = None + tmpdir = "checkpoint-test-%d" % random.randint(0, 10000) + + def setup(): + ssc = StreamingContext(self.sc, 0.5) + ssc.checkpoint(tmpdir) + stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams) + stream.transform(transformWithOffsetRanges).count().pprint() + return ssc + + try: + ssc1 = StreamingContext.getOrCreate(tmpdir, setup) + ssc1.start() + self.wait_for(offsetRanges, 1) + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + # To make sure some checkpoint is written + time.sleep(3) + ssc1.stop(False) + ssc1 = None + + # Restart again to make sure the checkpoint is recovered correctly + ssc2 = StreamingContext.getOrCreate(tmpdir, setup) + ssc2.start() + ssc2.awaitTermination(3) + ssc2.stop(stopSparkContext=False, stopGraceFully=True) + ssc2 = None + finally: + shutil.rmtree(tmpdir) + @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_message_handler(self): """Test Python direct Kafka RDD MessageHandler.""" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index c7f02bca2ae38..abbbf6eb9394f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,11 +37,11 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) self.failure = None def rdd_wrapper(self, func): - self._rdd_wrapper = func + self.rdd_wrap_func = func return self def call(self, milliseconds, jrdds): @@ -59,7 +59,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None + rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) @@ -101,7 +101,8 @@ def dumps(self, id): self.failure = None try: func = self.gateway.gateway_property.pool[id] - return bytearray(self.serializer.dumps((func.func, func.deserializers))) + return bytearray(self.serializer.dumps(( + func.func, func.rdd_wrap_func, func.deserializers))) except: self.failure = traceback.format_exc() @@ -109,8 +110,8 @@ def loads(self, data): # Clear the failure self.failure = None try: - f, deserializers = self.serializer.loads(bytes(data)) - return TransformFunction(self.ctx, f, *deserializers) + f, wrap_func, deserializers = self.serializer.loads(bytes(data)) + return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func) except: self.failure = traceback.format_exc() From 8f0ce8bfd787df1678b9e2ceec01ab74a8112c18 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 07:29:45 +0800 Subject: [PATCH 0999/1168] [SPARK-12075][SQL] Speed up HiveComparisionTest by avoiding / speeding up TestHive.reset() When profiling HiveCompatibilitySuite, I noticed that most of the time seems to be spent in expensive `TestHive.reset()` calls. This patch speeds up suites based on HiveComparisionTest, such as HiveCompatibilitySuite, with the following changes: - Avoid `TestHive.reset()` whenever possible: - Use a simple set of heuristics to guess whether we need to call `reset()` in between tests. - As a safety-net, automatically re-run failed tests by calling `reset()` before the re-attempt. - Speed up the expensive parts of `TestHive.reset()`: loading the `src` and `srcpart` tables took roughly 600ms per test, so we now avoid this by using a simple heuristic which only loads those tables by tests that reference them. This is based on simple string matching over the test queries which errs on the side of loading in more situations than might be strictly necessary. After these changes, HiveCompatibilitySuite seems to run in about 10 minutes. This PR is a revival of #6663, an earlier experimental PR from June, where I played around with several possible speedups for this suite. Author: Josh Rosen Closes #10055 from JoshRosen/speculative-testhive-reset. --- .../apache/spark/sql/hive/test/TestHive.scala | 7 -- .../hive/execution/HiveComparisonTest.scala | 67 +++++++++++++++++-- .../hive/execution/HiveQueryFileTest.scala | 2 +- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6883d305cbead..2e2d201bf254d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -443,13 +443,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { defaultOverrides() runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. - loadTestTable("src") - loadTestTable("srcpart") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index aa95ba94fa873..4455430aa727a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -209,7 +209,11 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + def createQueryTest( + testCaseName: String, + sql: String, + reset: Boolean = true, + tryWithoutResettingFirst: Boolean = false) { // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows assert(!testCaseName.contains(":")) @@ -240,9 +244,6 @@ abstract class HiveComparisonTest test(testCaseName) { logDebug(s"=== HIVE TEST: $testCaseName ===") - // Clear old output for this testcase. - outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val sqlWithoutComment = sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") val allQueries = @@ -269,11 +270,32 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - try { + def doTest(reset: Boolean, isSpeculative: Boolean = false): Unit = { + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + if (reset) { TestHive.reset() } + // Many tests drop indexes on src and srcpart at the beginning, so we need to load those + // tables here. Since DROP INDEX DDL is just passed to Hive, it bypasses the analyzer and + // thus the tables referenced in those DDL commands cannot be extracted for use by our + // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES + // command expect these tables to exist. + val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + for (table <- Seq("src", "srcpart")) { + val hasMatchingQuery = queryList.exists { query => + val normalizedQuery = query.toLowerCase.stripSuffix(";") + normalizedQuery.endsWith(table) || + normalizedQuery.contains(s"from $table") || + normalizedQuery.contains(s"from default.$table") + } + if (hasShowTableCommand || hasMatchingQuery) { + TestHive.loadTestTable(table) + } + } + val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" @@ -430,12 +452,45 @@ abstract class HiveComparisonTest """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) - fail(errorMessage) + if (isSpeculative && !reset) { + fail("Failed on first run; retrying") + } else { + fail(errorMessage) + } } } // Touch passed file. new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } + + val canSpeculativelyTryWithoutReset: Boolean = { + val excludedSubstrings = Seq( + "into table", + "create table", + "drop index" + ) + !queryList.map(_.toLowerCase).exists { query => + excludedSubstrings.exists(s => query.contains(s)) + } + } + + try { + try { + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + doTest(reset = false, isSpeculative = true) + } else { + doTest(reset) + } + } catch { + case tf: org.scalatest.exceptions.TestFailedException => + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + logWarning("Test failed without reset(); retrying with reset()") + doTest(reset = true) + } else { + throw tf + } + } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index f7b37dae0a5f3..f96c989c4614f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -59,7 +59,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) - createQueryTest(testCaseName, queriesString) + createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. From eb2d13a53d443a3715f7f90ab0bf040d9a875ae5 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 15:30:21 -0800 Subject: [PATCH 1000/1168] [SPARK-11328][SQL] Improve error message when hitting this issue The issue is that the output commiter is not idempotent and retry attempts will fail because the output file already exists. It is not safe to clean up the file as this output committer is by design not retryable. Currently, the job fails with a confusing file exists error. This patch is a stop gap to tell the user to look at the top of the error log for the proper message. This is difficult to test locally as Spark is hardcoded not to retry. Manually verified by upping the retry attempts. Author: Nong Li Author: Nong Li Closes #10080 from nongli/spark-11328. --- .../datasources/WriterContainer.scala | 22 +++++++++++++++++-- .../DirectParquetOutputCommitter.scala | 3 ++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 1b59b19d9420d..ad55367258890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -124,6 +124,24 @@ private[sql] abstract class BaseWriterContainer( } } + protected def newOutputWriter(path: String): OutputWriter = { + try { + outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + } catch { + case e: org.apache.hadoop.fs.FileAlreadyExistsException => + if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { + // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + // attempts, the task will fail because the output file is created from a prior attempt. + // This often means the most visible error to the user is misleading. Augment the error + // to tell the user to look for the actual error. + throw new SparkException("The output file already exists but this could be due to a " + + "failure from an earlier attempt. Look through the earlier logs or stage page for " + + "the first error.\n File exists error: " + e) + } + throw e + } + } + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) @@ -234,7 +252,7 @@ private[sql] class DefaultWriterContainer( executorSideSetup(taskContext) val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set("spark.sql.sources.output.path", outputPath) - val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) var writerClosed = false @@ -403,7 +421,7 @@ private[sql] class DynamicPartitionWriterContainer( val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) + val newWriter = super.newOutputWriter(path.toString) newWriter.initConverter(dataSchema) newWriter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 300e8677b312f..1a4e99ff10afb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -41,7 +41,8 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are * left empty). */ -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +private[datasources] class DirectParquetOutputCommitter( + outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) From 90bcb153bd79eb6e1f24bd2ae2a8a924b8dbad47 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 1 Dec 2015 15:32:57 -0800 Subject: [PATCH 1001/1168] [SPARK-11788][SQL] surround timestamp/date value with quotes in JDBC data source When query the Timestamp or Date column like the following val filtered = jdbcdf.where($"TIMESTAMP_COLUMN" >= beg && $"TIMESTAMP_COLUMN" < end) The generated SQL query is "TIMESTAMP_COLUMN >= 2015-01-01 00:00:00.0" It should have quote around the Timestamp/Date value such as "TIMESTAMP_COLUMN >= '2015-01-01 00:00:00.0'" Author: Huaxin Gao Closes #9872 from huaxingao/spark-11788. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 57a8a044a37cd..392d3ed58e3ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal @@ -267,6 +267,8 @@ private[sql] class JDBCRDD( */ private def compileValue(value: Any): Any = value match { case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" case _ => value } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce2..8c24aa3151bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -484,4 +484,15 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) } + + test("Test DataFrame.where for Date and Timestamp") { + // Regression test for bug SPARK-11788 + val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); + val date = java.sql.Date.valueOf("1995-01-01") + val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(0).getAs[java.sql.Timestamp](2) + === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) + } } From 7606bb6009ee718fa17816b5f5149fc8fa3b87b9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 16:24:04 -0800 Subject: [PATCH 1002/1168] [SPARK-11352][SQL] Escape */ in the generated comments. https://issues.apache.org/jira/browse/SPARK-11352 Author: Yin Huai Closes #10072 from yhuai/SPARK-11352. --- .../spark/sql/catalyst/expressions/Expression.scala | 10 ++++++++-- .../catalyst/expressions/codegen/CodegenFallback.scala | 2 +- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b55d3653a7158..4ee6542455a6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -95,7 +95,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* $this */" + val code = s"/* ${this.toCommentSafeString} */" GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -103,7 +103,7 @@ abstract class Expression extends TreeNode[Expression] { val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code.trim) + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) } } @@ -214,6 +214,12 @@ abstract class Expression extends TreeNode[Expression] { } override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. + */ + protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a31574c251af5..26fb143d1e45c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -33,7 +33,7 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") s""" - /* expression: ${this} */ + /* expression: ${this.toCommentSafeString} */ java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 002ed16dcfe7e..fe754240dcd67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -98,4 +98,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) assert(internalRow === internalRow2) } + + test("*/ in the data") { + // When */ appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape */ to \*\/. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)), + true, + InternalRow(UTF8String.fromString("*/"))) + } } From 8871b24da393b55fdacbbff3d66e2caecc64d22b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 17:18:45 -0800 Subject: [PATCH 1003/1168] [SPARK-11596][SQL] In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we should only return the simpleString. In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we will only return the simpleString. I tested the [following case provided by Cristian](https://issues.apache.org/jira/browse/SPARK-11596?focusedCommentId=15019241&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15019241). ``` val c = (1 to 20).foldLeft[Option[DataFrame]] (None) { (curr, idx) => println(s"PROCESSING >>>>>>>>>>> $idx") val df = sqlContext.sparkContext.parallelize((0 to 10).zipWithIndex).toDF("A", "B") val union = curr.map(_.unionAll(df)).getOrElse(df) union.cache() Some(union) } c.get.explain(true) ``` Without the change, `c.get.explain(true)` took 100s. With the change, `c.get.explain(true)` took 26ms. https://issues.apache.org/jira/browse/SPARK-11596 Author: Yin Huai Closes #10079 from yhuai/SPARK-11596. --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f1cea07976a37..ad2bd78430a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From f4f27a593b9da4f93f102bb52323b472c9ba1c8d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:36:34 -0800 Subject: [PATCH 1004/1168] [SPARK-8414] Ensure context cleaner periodic cleanups Garbage collection triggers cleanups. If the driver JVM is huge and there is little memory pressure, we may never clean up shuffle files on executors. This is a problem for long-running applications (e.g. streaming). Author: Andrew Or Closes #10070 from andrewor14/periodic-gc. --- .../org/apache/spark/ContextCleaner.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d23c1533db758..bc732535fed87 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.concurrent.{TimeUnit, ScheduledExecutorService} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Classes that represent cleaning tasks. @@ -66,6 +67,20 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + private val periodicGCService: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("context-cleaner-periodic-gc") + + /** + * How often to trigger a garbage collection in this JVM. + * + * This context cleaner triggers cleanups only when weak references are garbage collected. + * In long-running applications with large driver JVMs, where there is little memory pressure + * on the driver, this may happen very occasionally or not at all. Not cleaning at all may + * lead to executors running out of disk space after a while. + */ + private val periodicGCInterval = + sc.conf.getTimeAsSeconds("spark.cleaner.periodicGC.interval", "30min") + /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). @@ -104,6 +119,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() + periodicGCService.scheduleAtFixedRate(new Runnable { + override def run(): Unit = System.gc() + }, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS) } /** @@ -119,6 +137,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.interrupt() } cleaningThread.join() + periodicGCService.shutdown() } /** Register a RDD for cleanup when it is garbage collected. */ From 170a8d53210beb69f4c9d69a14d79fd7ae6c4a15 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:51:12 -0800 Subject: [PATCH 1005/1168] [SPARK-12081] Make unified memory manager work with small heaps The existing `spark.memory.fraction` (default 0.75) gives the system 25% of the space to work with. For small heaps, this is not enough: e.g. default 1GB leaves only 250MB system memory. This is especially a problem in local mode, where the driver and executor are crammed in the same JVM. Members of the community have reported driver OOM's in such cases. **New proposal.** We now reserve 300MB before taking the 75%. For 1GB JVMs, this leaves `(1024 - 300) * 0.75 = 543MB` for execution and storage. This is proposal (1) listed in the [JIRA](https://issues.apache.org/jira/browse/SPARK-12081). Author: Andrew Or Closes #10081 from andrewor14/unified-memory-small-heaps. --- .../spark/memory/UnifiedMemoryManager.scala | 22 +++++++++++++++---- .../memory/UnifiedMemoryManagerSuite.scala | 20 +++++++++++++++++ docs/configuration.md | 4 ++-- docs/tuning.md | 2 +- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 8be5b05419094..48b4e23433e43 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. * - * The region shared between execution and storage is a fraction of the total heap space + * The region shared between execution and storage is a fraction of (the total heap space - 300MB) * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary * within this space is further determined by `spark.memory.storageFraction` (default 0.5). * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. @@ -48,7 +48,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} */ private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, - maxMemory: Long, + val maxMemory: Long, private val storageRegionSize: Long, numCores: Int) extends MemoryManager( @@ -130,6 +130,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( object UnifiedMemoryManager { + // Set aside a fixed amount of memory for non-storage, non-execution purposes. + // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve + // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then + // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { val maxMemory = getMaxMemory(conf) new UnifiedMemoryManager( @@ -144,8 +150,16 @@ object UnifiedMemoryManager { * Return the total amount of memory shared between execution and storage, in bytes. */ private def getMaxMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val reservedMemory = conf.getLong("spark.testing.reservedMemory", + if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val minSystemMemory = reservedMemory * 1.5 + if (systemMemory < minSystemMemory) { + throw new IllegalArgumentException(s"System memory $systemMemory must " + + s"be at least $minSystemMemory. Please use a larger heap size.") + } + val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) - (systemMaxMemory * memoryFraction).toLong + (usableMemory * memoryFraction).toLong } } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 8cebe81c3bfff..e97c898a44783 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -182,4 +182,24 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assertEnsureFreeSpaceCalled(ms, 850L) } + test("small heap") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong + assert(mm.maxMemory === expectedMaxMemory) + + // Try using a system memory that's too small + val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("larger heap size")) + } + } diff --git a/docs/configuration.md b/docs/configuration.md index 741d6b2b37a87..c39b4890851bc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -719,8 +719,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.fraction 0.75 - Fraction of the heap space used for execution and storage. The lower this is, the more - frequently spills and cached data eviction occur. The purpose of this config is to set + Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the + more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation in the case of sparse, unusually large records. Leaving this at the default value is recommended. For more detail, see diff --git a/docs/tuning.md b/docs/tuning.md index a8fe7a4532798..e73ed69ffbbf8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -114,7 +114,7 @@ variety of workloads without requiring user expertise of how memory is divided i Although there are two relevant configurations, the typical user should not need to adjust them as the default values are applicable to most workloads: -* `spark.memory.fraction` expresses the size of `M` as a fraction of the total JVM heap space +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) (default 0.75). The rest of the space (25%) is reserved for user data structures, internal metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually large records. From b70c651d5a3d56421081583cf114b2f851130ef3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 20:17:12 -0800 Subject: [PATCH 1006/1168] [SPARK-12077][SQL] change the default plan for single distinct Use try to match the behavior for single distinct aggregation with Spark 1.5, but that's not scalable, we should be robust by default, have a flag to address performance regression for low cardinality aggregation. cc yhuai nongli Author: Davies Liu Closes #10075 from davies/agg_15. --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5ef3a48c56a87..58adf64e49869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -451,7 +451,7 @@ private[spark] object SQLConf { val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(true), + defaultValue = Some(false), isPublic = false, doc = "When true, if a query only has a single distinct column and it has " + "grouping expressions, we will use our planner rule to handle this distinct " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index dfec139985f73..a4626259b2823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -44,10 +44,10 @@ class PlannerSuite extends SharedSQLContext { fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be three aggregate operator for + // For the new aggregation code path, there will be four aggregate operator for // distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 3, + aggregations.size == 2 || aggregations.size == 4, s"The plan of query $query does not have partial aggregations.") } From 041df5050fa4ef24b27b1c88b5afbaeee595c798 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 21:04:52 -0800 Subject: [PATCH 1007/1168] [SPARK-12087][STREAMING] Create new JobConf for every batch in saveAsHadoopFiles The JobConf object created in `DStream.saveAsHadoopFiles` is used concurrently in multiple places: * The JobConf is updated by `RDD.saveAsHadoopFile()` before the job is launched * The JobConf is serialized as part of the DStream checkpoints. These concurrent accesses (updating in one thread, while the another thread is serializing it) can lead to concurrentModidicationException in the underlying Java hashmap using in the internal Hadoop Configuration object. The solution is to create a new JobConf in every batch, that is updated by `RDD.saveAsHadoopFile()`, while the checkpointing serializes the original JobConf. Tests to be added in #9988 will fail reliably without this patch. Keeping this patch really small to make sure that it can be added to previous branches. Author: Tathagata Das Closes #10088 from tdas/SPARK-12087. --- .../apache/spark/streaming/dstream/PairDStreamFunctions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index fb691eed27e32..2762309134eb1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -730,7 +730,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, + new JobConf(serializableConf.value)) } self.foreachRDD(saveFunc) } From f5ef4d119e8216b964a317c20eaad33f9f4851db Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 21:51:33 -0800 Subject: [PATCH 1008/1168] [SPARK-11949][SQL] Check bitmasks to set nullable property Following up #10038. We can use bitmasks to determine which grouping expressions need to be set as nullable. cc yhuai Author: Liang-Chi Hsieh Closes #10067 from viirya/fix-cube-following. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 765327c474e69..d3163dcd4db94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -224,10 +224,15 @@ class Analyzer( case other => Alias(other, other.toString)() } - // TODO: We need to use bitmasks to determine which grouping expressions need to be - // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need - // to change the nullability of a. - val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val nonNullBitmask = x.bitmasks.reduce(_ & _) + + val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => + if ((nonNullBitmask & 1 << idx) == 0) { + (a -> a.toAttribute.withNullability(true)) + } else { + (a -> a.toAttribute) + } + }.toMap val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change From 1bca9186811a6601d34d615d192230e89ef436fb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 22:41:48 -0800 Subject: [PATCH 1009/1168] [SPARK-12090] [PYSPARK] consider shuffle in coalesce() Author: Davies Liu Closes #10090 from davies/fix_coalesce. --- python/pyspark/rdd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b4d59647b2bc..00bb9a62e904a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2015,7 +2015,7 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions) + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): From 65278855958dff52f25f15ba90706974d7ddf0f0 Mon Sep 17 00:00:00 2001 From: Jeroen Schot Date: Wed, 2 Dec 2015 09:40:07 +0000 Subject: [PATCH 1010/1168] [SPARK-3580][CORE] Add Consistent Method To Get Number of RDD Partitions Across Different Languages I have tried to address all the comments in pull request https://github.com/apache/spark/pull/2447. Note that the second commit (using the new method in all internal code of all components) is quite intrusive and could be omitted. Author: Jeroen Schot Closes #9767 from schot/master. --- .../org/apache/spark/api/java/JavaRDDLike.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +++++++- .../test/java/org/apache/spark/JavaAPISuite.java | 13 +++++++++++++ .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 1 + project/MimaExcludes.scala | 4 ++++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1e9d4f1803a81..0e4d7dce0f2f5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -28,6 +28,7 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap @@ -62,6 +63,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** Set of partitions in this RDD. */ def partitions: JList[Partition] = rdd.partitions.toSeq.asJava + /** Return the number of partitions in this RDD. */ + @Since("1.6.0") + def getNumPartitions: Int = rdd.getNumPartitions + /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8b3731d935788..9fe9d83a705b2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -242,6 +242,12 @@ abstract class RDD[T: ClassTag]( } } + /** + * Returns the number of partitions of this RDD. + */ + @Since("1.6.0") + final def getNumPartitions: Int = partitions.length + /** * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4d4e9820500e7..11f1248c24d38 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -973,6 +973,19 @@ public Iterator call(Integer index, Iterator iter) { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + @Test + public void getNumPartitions(){ + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); + JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), 2); + Assert.assertEquals(3, rdd1.getNumPartitions()); + Assert.assertEquals(2, rdd2.getNumPartitions()); + Assert.assertEquals(2, rdd3.getNumPartitions()); + } @Test public void repartition() { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5f718ea9f7be1..46ed5c04f4338 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 566bfe8efb7a4..d3a3c0ceb68c8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -155,6 +155,10 @@ object MimaExcludes { "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") ) case v if v.startsWith("1.5") => Seq( From 1117281cb099d0d243e747509c4193e3eb901eb1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 2 Dec 2015 09:36:12 -0800 Subject: [PATCH 1011/1168] [SPARK-12094][SQL] Prettier tree string for TreeNode When examining plans of complex queries with multiple joins, a pain point of mine is that, it's hard to immediately see the sibling node of a specific query plan node. This PR adds tree lines for the tree string of a `TreeNode`, so that the result can be visually more intuitive. Author: Cheng Lian Closes #10099 from liancheng/prettier-tree-string. --- .../spark/sql/catalyst/trees/TreeNode.scala | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ad2bd78430a68..dfea583e01465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -393,7 +393,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -419,12 +419,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - /** Appends the string represent of this node and its children to the given StringBuilder. */ - protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { - builder.append(" " * depth) + /** + * Appends the string represent of this node and its children to the given StringBuilder. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. + */ + protected def generateTreeString( + depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + builder.append(simpleString) builder.append("\n") - children.foreach(_.generateTreeString(depth + 1, builder)) + + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + builder } From 6876c50008d813a50de6c51bbdcf1517fdcd994a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 13:44:01 -0800 Subject: [PATCH 1012/1168] [SPARK-12001] Allow partially-stopped StreamingContext to be completely stopped If `StreamingContext.stop()` is interrupted midway through the call, the context will be marked as stopped but certain state will have not been cleaned up. Because `state = STOPPED` will be set, subsequent `stop()` calls will be unable to finish stopping the context, preventing any new StreamingContexts from being created. This patch addresses this issue by only marking the context as `STOPPED` once the `stop()` has successfully completed which allows `stop()` to be called a second time in order to finish stopping the context in case the original `stop()` call was interrupted. I discovered this issue by examining logs from a failed Jenkins run in which this race condition occurred in `FailureSuite`, leaking an unstoppable context and causing all subsequent tests to fail. Author: Josh Rosen Closes #9982 from JoshRosen/SPARK-12001. --- .../spark/streaming/StreamingContext.scala | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 6fb8ad38abcec..cf843e3e8b8ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -699,28 +699,33 @@ class StreamingContext private[streaming] ( " AsynchronousListenerBus") } synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => - scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) - waiter.notifyStop() - if (shutdownHookRef != null) { - shutdownHookRefToRemove = shutdownHookRef - shutdownHookRef = null - } - logInfo("StreamingContext stopped successfully") - } - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + state = STOPPED + case STOPPED => + logWarning("StreamingContext has already been stopped") + state = STOPPED + case ACTIVE => + // It's important that we don't set state = STOPPED until the very end of this case, + // since we need to ensure that we're still able to call `stop()` to recover from + // a partially-stopped StreamingContext which resulted from this `stop()` call being + // interrupted. See SPARK-12001 for more details. Because the body of this case can be + // executed twice in the case of a partial stop, all methods called here need to be + // idempotent. + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + state = STOPPED } } if (shutdownHookRefToRemove != null) { From 2e3c99468520b0499ef78e64d032d41abcbc5e6a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 2 Dec 2015 14:15:54 -0800 Subject: [PATCH 1013/1168] [SPARK-10266][DOCUMENTATION, ML] Fixed @Since annotation for ml.tunning cc mengxr noel-smith I worked on this issues based on https://github.com/apache/spark/pull/8729. ehsanmok thank you for your contricution! Author: Yu ISHIKAWA Author: Ehsan M.Kermani Closes #9338 from yu-iskw/JIRA-10266. --- .../spark/ml/tuning/CrossValidator.scala | 34 ++++++++++++++----- .../spark/ml/tuning/ParamGridBuilder.scala | 14 ++++++-- .../ml/tuning/TrainValidationSplit.scala | 26 +++++++++++--- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 83a9048374267..5c09f1aaff80d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,18 +19,18 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.{JObject, DefaultFormats} import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JObject} -import org.apache.spark.ml.classification.OneVsRestParams -import org.apache.spark.ml.feature.RFormulaModel -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ +import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -58,26 +58,34 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { * :: Experimental :: * K-fold cross validation. */ +@Since("1.2.0") @Experimental -class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] +class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) + extends Estimator[CrossValidatorModel] with CrossValidatorParams with MLWritable with Logging { + @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS /** @group setParam */ + @Since("1.2.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.2.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.2.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -116,10 +124,12 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -128,6 +138,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { val copied = defaultCopy(extra).asInstanceOf[CrossValidator] if (copied.isDefined(estimator)) { @@ -308,26 +319,31 @@ object CrossValidator extends MLReadable[CrossValidator] { * @param avgMetrics Average cross-validation metrics for each paramMap in * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ +@Since("1.2.0") @Experimental class CrossValidatorModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) + @Since("1.4.0") override val uid: String, + @Since("1.2.0") val bestModel: Model[_], + @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + @Since("1.4.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidatorModel = { val copied = new CrossValidatorModel( uid, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index 98a8f0330ca45..b836d2a2340e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ /** * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ +@Since("1.2.0") @Experimental -class ParamGridBuilder { +class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") def baseOn(paramMap: ParamMap): this.type = { baseOn(paramMap.toSeq: _*) this @@ -43,6 +45,7 @@ class ParamGridBuilder { /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") @varargs def baseOn(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -54,6 +57,7 @@ class ParamGridBuilder { /** * Adds a param with multiple values (overwrites if the input param exists). */ + @Since("1.2.0") def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { paramGrid.put(param, values) this @@ -64,6 +68,7 @@ class ParamGridBuilder { /** * Adds a double param with multiple values. */ + @Since("1.2.0") def addGrid(param: DoubleParam, values: Array[Double]): this.type = { addGrid[Double](param, values) } @@ -71,6 +76,7 @@ class ParamGridBuilder { /** * Adds a int param with multiple values. */ + @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { addGrid[Int](param, values) } @@ -78,6 +84,7 @@ class ParamGridBuilder { /** * Adds a float param with multiple values. */ + @Since("1.2.0") def addGrid(param: FloatParam, values: Array[Float]): this.type = { addGrid[Float](param, values) } @@ -85,6 +92,7 @@ class ParamGridBuilder { /** * Adds a long param with multiple values. */ + @Since("1.2.0") def addGrid(param: LongParam, values: Array[Long]): this.type = { addGrid[Long](param, values) } @@ -92,6 +100,7 @@ class ParamGridBuilder { /** * Adds a boolean param with true and false. */ + @Since("1.2.0") def addGrid(param: BooleanParam): this.type = { addGrid[Boolean](param, Array(true, false)) } @@ -99,6 +108,7 @@ class ParamGridBuilder { /** * Builds and returns all combinations of parameters specified by the param grid. */ + @Since("1.2.0") def build(): Array[ParamMap] = { var paramMaps = Array(new ParamMap) paramGrid.foreach { case (param, values) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 73a14b8310157..adf06302047a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} @@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. */ +@Since("1.5.0") @Experimental -class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] +class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with Logging { + @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) /** @group setParam */ + @Since("1.5.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.5.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.5.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -108,10 +116,12 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.5.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -120,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali } } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] if (copied.isDefined(estimator)) { @@ -140,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali * @param bestModel Estimator determined best model. * @param validationMetrics Evaluated validation metrics. */ +@Since("1.5.0") @Experimental class TrainValidationSplitModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val validationMetrics: Array[Double]) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val bestModel: Model[_], + @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + @Since("1.5.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplitModel = { val copied = new TrainValidationSplitModel ( uid, From 7148fb90d7f7d7b8db8b288741f0a3873d2ec810 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Thu, 3 Dec 2015 08:48:49 +0800 Subject: [PATCH 1014/1168] [SPARK-12093][SQL] Fix the error of comment in DDLParser Author: Yadong Qi Closes #10096 from watermen/patch-1. --- .../apache/spark/sql/execution/datasources/DDLParser.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 6969b423d01b9..f22508b21090c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -66,15 +66,15 @@ class DDLParser(parseQuery: String => LogicalPlan) protected def start: Parser[LogicalPlan] = ddl /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * AS SELECT ... From 7f4868d354dac6daf1cea822e809ff9ef782a843 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Dec 2015 17:19:31 -0800 Subject: [PATCH 1015/1168] [SPARK-12000] do not specify arg types when reference a method in ScalaDoc This fixes SPARK-12000, verified on my local with JDK 7. It seems that `scaladoc` try to match method names and messed up with annotations. cc: JoshRosen jkbradley Author: Xiangrui Meng Closes #10114 from mengxr/SPARK-12000.2. --- .../org/apache/spark/mllib/clustering/BisectingKMeans.scala | 2 +- .../apache/spark/mllib/clustering/BisectingKMeansModel.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 29a7aa0bb63f2..82adfa6ffd596 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -214,7 +214,7 @@ class BisectingKMeans private ( } /** - * Java-friendly version of [[run(RDD[Vector])*]] + * Java-friendly version of [[run()]]. */ def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 5015f1540d920..f942e5613ffaf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -64,7 +64,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[predict(RDD[Vector])*]] + * Java-friendly version of [[predict()]]. */ @Since("1.6.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = @@ -88,7 +88,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[computeCost(RDD[Vector])*]]. + * Java-friendly version of [[computeCost()]]. */ @Since("1.6.0") def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) From 73b7a6af3084487e4e0baf6cb25ba4b75c8d2089 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Dec 2015 11:12:02 +0800 Subject: [PATCH 1016/1168] [SPARK-12082][FLAKY-TEST] Increase timeouts in NettyBlockTransferSecuritySuite We should try increasing a timeout in NettyBlockTransferSecuritySuite in order to reduce that suite's flakiness in Jenkins. Author: Josh Rosen Closes #10113 from JoshRosen/SPARK-12082. --- .../spark/network/netty/NettyBlockTransferSecuritySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 3940527fb874e..98da94139f7f8 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -148,7 +148,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } }) - Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } From be50460352e36604d389b0dba0c75f45b21fc722 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 3 Dec 2015 11:21:24 +0800 Subject: [PATCH 1017/1168] [SPARK-12109][SQL] Expressions's simpleString should delegate to its toString. https://issues.apache.org/jira/browse/SPARK-12109 The change of https://issues.apache.org/jira/browse/SPARK-11596 exposed the problem. In the sql plan viz, the filter shows ![image](https://cloud.githubusercontent.com/assets/2072857/11547075/1a285230-9906-11e5-8481-2bb451e35ef1.png) After changes in this PR, the viz is back to normal. ![image](https://cloud.githubusercontent.com/assets/2072857/11547080/2bc570f4-9906-11e5-8897-3b3bff173276.png) Author: Yin Huai Closes #10111 from yhuai/SPARK-12109. --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 3 ++- .../spark/sql/catalyst/expressions/windowExpressions.scala | 3 --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4ee6542455a6c..614f0c075fd23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -207,12 +207,13 @@ abstract class Expression extends TreeNode[Expression] { }.toString } - private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } + override def simpleString: String = toString + override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa44..1680aa8252ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -71,9 +71,6 @@ case class WindowSpecDefinition( childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] - - override def toString: String = simpleString - override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index dfea583e01465..d838d845d20fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From 4eefdf6e204f4cd286747da9d4b15742409cf9bb Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 3 Dec 2015 08:42:21 +0000 Subject: [PATCH 1018/1168] =?UTF-8?q?[SPARK-12088][SQL]=20check=20connecti?= =?UTF-8?q?on.isClosed=20before=20calling=20connection=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In Java Spec java.sql.Connection, it has boolean getAutoCommit() throws SQLException Throws: SQLException - if a database access error occurs or this method is called on a closed connection So if conn.getAutoCommit is called on a closed connection, a SQLException will be thrown. Even though the code catch the SQLException and program can continue, I think we should check conn.isClosed before calling conn.getAutoCommit to avoid the unnecessary SQLException. Author: Huaxin Gao Closes #10095 from huaxingao/spark-12088. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 392d3ed58e3ce..b9dd7f6b4099b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -498,7 +498,7 @@ private[sql] class JDBCRDD( } try { if (null != conn) { - if (!conn.getAutoCommit && !conn.isClosed) { + if (!conn.isClosed && !conn.getAutoCommit) { try { conn.commit() } catch { From 6d5a6f91fd196f67830edb1b2188a647ab6728ba Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 3 Dec 2015 15:36:28 +0000 Subject: [PATCH 1019/1168] [DOCUMENTATION][MLLIB] typo in mllib doc \cc mengxr Author: Jeff Zhang Closes #10093 from zjffdu/mllib_typo. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 5f888775553f2..05c2c96c5ec5a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1232,7 +1232,7 @@ lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) * `withStd`: True by default. Scales the data to unit standard deviation. * `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. +`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. From c1c5f56b3ba5507b2c94d1c5d0ab4a81399578f8 Mon Sep 17 00:00:00 2001 From: microwishing Date: Thu, 3 Dec 2015 16:09:05 +0000 Subject: [PATCH 1020/1168] [DOCUMENTATION][KAFKA] fix typo in kafka/OffsetRange.scala this is to fix some typo in external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala Author: microwishing Closes #10121 from microwishing/master. --- .../scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala | 2 +- .../scala/org/apache/spark/streaming/kafka/OffsetRange.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 86394ea8a685e..45a6982b9afe5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -151,7 +151,7 @@ private[kafka] class KafkaTestUtils extends Logging { } } - /** Create a Kafka topic and wait until it propagated to the whole cluster */ + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 8a5f371494511..d9b856e4697a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition /** - * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). * {{{ From a67fbee1830c13f59c84075d6acbf12767bf29ec Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 09:22:21 -0800 Subject: [PATCH 1021/1168] [SPARK-12116][SPARKR][DOCS] document how to workaround function name conflicts with dplyr shivaram Author: felixcheung Closes #10119 from felixcheung/rdocdplyrmasked. --- docs/sparkr.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index cfb9b41350f45..01148786b79d7 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -384,5 +384,6 @@ The following functions are masked by the SparkR package: +Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. + You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) - From 3ed25b4db7e530dfeb6c8db7751d1f6f49f0a216 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 3 Dec 2015 10:33:06 -0800 Subject: [PATCH 1022/1168] [SPARK-11314][YARN] add service API and test service for Yarn Cluster schedulers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is purely the yarn/src/main and yarn/src/test bits of the YARN ATS integration: the extension model to load and run implementations of `SchedulerExtensionService` in the yarn cluster scheduler process —and to stop them afterwards. There's duplication between the two schedulers, yarn-client and yarn-cluster, at least in terms of setting everything up, because the common superclass, `YarnSchedulerBackend` is in spark-core, and the extension services need the YARN app/attempt IDs. If you look at how the the extension services are loaded, the case class `SchedulerExtensionServiceBinding` is used to pass in config info -currently just the spark context and the yarn IDs, of which one, the attemptID, will be null when running client-side. I'm passing in a case class to ensure that it would be possible in future to add extra arguments to the binding class, yet, as the method signature will not have changed, still be able to load existing services. There's no functional extension service here, just one for testing. The real tests come in the bigger pull requests. At the same time, there's no restriction of this extension service purely to the ATS history publisher. Anything else that wants to listen to the spark context and publish events could use this, and I'd also consider writing one for the YARN-913 registry service, so that the URLs of the web UI would be locatable through that (low priority; would make more sense if integrated with a REST client). There's no minicluster test. Given the test execution overhead of setting up minicluster tests, it'd probably be better to add an extension service into one of the existing tests. Author: Steve Loughran Closes #9182 from steveloughran/stevel/feature/SPARK-1537-service. --- .../spark/deploy/yarn/ApplicationMaster.scala | 8 + .../cluster/SchedulerExtensionService.scala | 154 ++++++++++++++++++ .../cluster/YarnClientSchedulerBackend.scala | 22 +-- .../cluster/YarnClusterSchedulerBackend.scala | 20 +-- .../cluster/YarnSchedulerBackend.scala | 70 +++++++- .../ExtensionServiceIntegrationSuite.scala | 71 ++++++++ .../cluster/SimpleExtensionService.scala | 34 ++++ .../cluster/StubApplicationAttemptId.scala | 48 ++++++ .../scheduler/cluster/StubApplicationId.scala | 42 +++++ 9 files changed, 431 insertions(+), 38 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala rename {core => yarn}/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala (81%) create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 50ae7ffeec4c5..13ef4dfd64165 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -117,6 +117,10 @@ private[spark] class ApplicationMaster( private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + def getAttemptId(): ApplicationAttemptId = { + client.getAttemptId() + } + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -662,6 +666,10 @@ object ApplicationMaster extends Logging { master.sparkContextStopped(sc) } + private[spark] def getAttemptId(): ApplicationAttemptId = { + master.getAttemptId + } + } /** diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala new file mode 100644 index 0000000000000..c064521845399 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.util.Utils + +/** + * An extension service that can be loaded into a Spark YARN scheduler. + * A Service that can be started and stopped. + * + * 1. For implementations to be loadable by `SchedulerExtensionServices`, + * they must provide an empty constructor. + * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ +trait SchedulerExtensionService { + + /** + * Start the extension service. This should be a no-op if + * called more than once. + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit + + /** + * Stop the service + * The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ + def stop(): Unit +} + +/** + * Binding information for a [[SchedulerExtensionService]]. + * + * The attempt ID will be set if the service is started within a YARN application master; + * there is then a different attempt ID for every time that AM is restarted. + * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks + * this information. + * @param sparkContext current spark context + * @param applicationId YARN application ID + * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in + * cluster mode. + */ +case class SchedulerExtensionServiceBinding( + sparkContext: SparkContext, + applicationId: ApplicationId, + attemptId: Option[ApplicationAttemptId] = None) + +/** + * Container for [[SchedulerExtensionService]] instances. + * + * Loads Extension Services from the configuration property + * `"spark.yarn.services"`, instantiates and starts them. + * When stopped, it stops all child entries. + * + * The order in which child extension services are started and stopped + * is undefined. + */ +private[spark] class SchedulerExtensionServices extends SchedulerExtensionService + with Logging { + private var serviceOption: Option[String] = None + private var services: List[SchedulerExtensionService] = Nil + private val started = new AtomicBoolean(false) + private var binding: SchedulerExtensionServiceBinding = _ + + /** + * Binding operation will load the named services and call bind on them too; the + * entire set of services are then ready for `init()` and `start()` calls. + * + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit = { + if (started.getAndSet(true)) { + logWarning("Ignoring re-entrant start operation") + return + } + require(binding.sparkContext != null, "Null context parameter") + require(binding.applicationId != null, "Null appId parameter") + this.binding = binding + val sparkContext = binding.sparkContext + val appId = binding.applicationId + val attemptId = binding.attemptId + logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId") + + serviceOption = sparkContext.getConf.getOption(SchedulerExtensionServices.SPARK_YARN_SERVICES) + services = serviceOption + .map { s => + s.split(",").map(_.trim()).filter(!_.isEmpty) + .map { sClass => + val instance = Utils.classForName(sClass) + .newInstance() + .asInstanceOf[SchedulerExtensionService] + // bind this service + instance.start(binding) + logInfo(s"Service $sClass started") + instance + }.toList + }.getOrElse(Nil) + } + + /** + * Get the list of services. + * + * @return a list of services; Nil until the service is started + */ + def getServices: List[SchedulerExtensionService] = services + + /** + * Stop the services; idempotent. + * + */ + override def stop(): Unit = { + if (started.getAndSet(false)) { + logInfo(s"Stopping $this") + services.foreach { s => + Utils.tryLogNonFatalError(s.stop()) + } + } + } + + override def toString(): String = s"""SchedulerExtensionServices + |(serviceOption=$serviceOption, + | services=$services, + | started=$started)""".stripMargin +} + +private[spark] object SchedulerExtensionServices { + + /** + * A list of comma separated services to instantiate in the scheduler + */ + val SPARK_YARN_SERVICES = "spark.yarn.services" +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 20771f655473c..0e27a2665e939 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} @@ -33,7 +33,6 @@ private[spark] class YarnClientSchedulerBackend( with Logging { private var client: Client = null - private var appId: ApplicationId = null private var monitorThread: MonitorThread = null /** @@ -54,13 +53,12 @@ private[spark] class YarnClientSchedulerBackend( val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.submitApplication() + bindToYarn(client.submitApplication(), None) // SPARK-8687: Ensure all necessary properties have already been set before // we initialize our driver scheduler backend, which serves these properties // to the executors super.start() - waitForApplication() // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver @@ -116,8 +114,8 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def waitForApplication(): Unit = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -125,7 +123,7 @@ private[spark] class YarnClientSchedulerBackend( "It might have been killed or unable to launch application master.") } if (state == YarnApplicationState.RUNNING) { - logInfo(s"Application $appId has started running.") + logInfo(s"Application ${appId.get} has started running.") } } @@ -141,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) logError(s"Yarn application has already exited with state $state!") allowInterrupt = false sc.stop() @@ -163,7 +161,7 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def asyncMonitorApplication(): MonitorThread = { - assert(client != null && appId != null, "Application has not been submitted yet!") + assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) @@ -193,10 +191,4 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } - override def applicationId(): String = { - Option(appId).map(_.toString).getOrElse { - logWarning("Application ID is not initialized yet.") - super.applicationId - } - } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 50b699f11b21c..ced597bed36d9 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -31,26 +31,12 @@ private[spark] class YarnClusterSchedulerBackend( extends YarnSchedulerBackend(scheduler, sc) { override def start() { + val attemptId = ApplicationMaster.getAttemptId + bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } - override def applicationId(): String = - // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.id").getOrElse { - logError("Application ID is not set.") - super.applicationId - } - - override def applicationAttemptId(): Option[String] = - // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.attemptId").orElse { - logError("Application attempt ID is not set.") - super.applicationAttemptId - } - override def getDriverLogUrls: Option[Map[String, String]] = { var driverLogs: Option[Map[String, String]] = None try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala similarity index 81% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 80da37b09b590..e3dd87798f018 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,17 +17,17 @@ package org.apache.spark.scheduler.cluster -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{ThreadUtils, RpcUtils} - -import scala.util.control.NonFatal +import org.apache.spark.util.{RpcUtils, ThreadUtils} /** * Abstract Yarn scheduler backend that contains common logic @@ -51,6 +51,64 @@ private[spark] abstract class YarnSchedulerBackend( private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) + /** Application ID. */ + protected var appId: Option[ApplicationId] = None + + /** Attempt ID. This is unset for client-mode schedulers */ + private var attemptId: Option[ApplicationAttemptId] = None + + /** Scheduler extension services. */ + private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + + /** + * Bind to YARN. This *must* be done before calling [[start()]]. + * + * @param appId YARN application ID + * @param attemptId Optional YARN attempt ID + */ + protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = { + this.appId = Some(appId) + this.attemptId = attemptId + } + + override def start() { + require(appId.isDefined, "application ID unset") + val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId) + services.start(binding) + super.start() + } + + override def stop(): Unit = { + try { + super.stop() + } finally { + services.stop() + } + } + + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * + * @return The application attempt id, if available. + */ + override def applicationAttemptId(): Option[String] = { + attemptId.map(_.toString) + } + + /** + * Get an application ID associated with the job. + * This returns the string value of [[appId]] if set, otherwise + * the locally-generated ID from the superclass. + * @return The application ID + */ + override def applicationId(): String = { + appId.map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } + /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala new file mode 100644 index 0000000000000..b4d1b0a3d22a7 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{LocalSparkContext, Logging, SparkConf, SparkContext, SparkFunSuite} + +/** + * Test the integration with [[SchedulerExtensionServices]] + */ +class ExtensionServiceIntegrationSuite extends SparkFunSuite + with LocalSparkContext with BeforeAndAfter + with Logging { + + val applicationId = new StubApplicationId(0, 1111L) + val attemptId = new StubApplicationAttemptId(applicationId, 1) + + /* + * Setup phase creates the spark context + */ + before { + val sparkConf = new SparkConf() + sparkConf.set(SchedulerExtensionServices.SPARK_YARN_SERVICES, + classOf[SimpleExtensionService].getName()) + sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite") + sc = new SparkContext(sparkConf) + } + + test("Instantiate") { + val services = new SchedulerExtensionServices() + assertResult(Nil, "non-nil service list") { + services.getServices + } + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + services.stop() + } + + test("Contains SimpleExtensionService Service") { + val services = new SchedulerExtensionServices() + try { + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + val serviceList = services.getServices + assert(serviceList.nonEmpty, "empty service list") + val (service :: Nil) = serviceList + val simpleService = service.asInstanceOf[SimpleExtensionService] + assert(simpleService.started.get, "service not started") + services.stop() + assert(!simpleService.started.get, "service not stopped") + } finally { + services.stop() + } + } +} + + diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala new file mode 100644 index 0000000000000..9b8c98cda8da8 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +private[spark] class SimpleExtensionService extends SchedulerExtensionService { + + /** started flag; set in the `start()` call, stopped in `stop()`. */ + val started = new AtomicBoolean(false) + + override def start(binding: SchedulerExtensionServiceBinding): Unit = { + started.set(true) + } + + override def stop(): Unit = { + started.set(false) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala new file mode 100644 index 0000000000000..4b57b9509a655 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +/** + * A stub application ID; can be set in constructor and/or updated later. + * @param applicationId application ID + * @param attempt an attempt counter + */ +class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int) + extends ApplicationAttemptId { + + override def setApplicationId(appID: ApplicationId): Unit = { + applicationId = appID + } + + override def getAttemptId: Int = { + attempt + } + + override def setAttemptId(attemptId: Int): Unit = { + attempt = attemptId + } + + override def getApplicationId: ApplicationId = { + applicationId + } + + override def build(): Unit = { + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala new file mode 100644 index 0000000000000..bffa0e09befd2 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.ApplicationId + +/** + * Simple Testing Application Id; ID and cluster timestamp are set in constructor + * and cannot be updated. + * @param id app id + * @param clusterTimestamp timestamp + */ +private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId { + override def getId: Int = { + id + } + + override def getClusterTimestamp: Long = { + clusterTimestamp + } + + override def setId(id: Int): Unit = {} + + override def setClusterTimestamp(clusterTimestamp: Long): Unit = {} + + override def build(): Unit = {} +} From 2bb1336ae731853de633f4d8c28f007203d403f0 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 Dec 2015 11:05:12 -0800 Subject: [PATCH 1023/1168] [SPARK-12059][CORE] Avoid assertion error when unexpected state transition met in Master Downgrade to warning log for unexpected state transition. andrewor14 please review, thanks a lot. Author: jerryshao Closes #10091 from jerryshao/SPARK-12059. --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b523..04b20e0d6ab9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,8 +257,9 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - assert(oldState == ExecutorState.LAUNCHING, - s"executor $execId state transfer from $oldState to RUNNING is illegal") + if (oldState != ExecutorState.LAUNCHING) { + logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") + } appInfo.resetRetryCount() } From 582cc93c685100d04190df1f3b9ca9fd522bf4ad Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 3 Dec 2015 11:06:25 -0800 Subject: [PATCH 1024/1168] [SPARK-12101][CORE] Fix thread pools that cannot cache tasks in Worker and AppClient `SynchronousQueue` cannot cache any task. This issue is similar to #9978. It's an easy fix. Just use the fixed `ThreadUtils.newDaemonCachedThreadPool`. Author: Shixiong Zhu Closes #10108 from zsxwing/fix-threadpool. --- .../org/apache/spark/deploy/client/AppClient.scala | 10 ++++------ .../org/apache/spark/deploy/worker/Worker.scala | 10 ++++------ .../apache/spark/deploy/yarn/YarnAllocator.scala | 14 ++++---------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index df6ba7d669ce9..1e2f469214b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -68,12 +68,10 @@ private[spark] class AppClient( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. - private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.length, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "appclient-register-master-threadpool", + masterRpcAddresses.length // Make sure we can register with all masters at the same time + ) // A scheduled executor for scheduling the registration actions private val registrationRetryThread = diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 418faf8fc967f..1afc1ff59f2f9 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -146,12 +146,10 @@ private[deploy] class Worker( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. - private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "worker-register-master-threadpool", + masterRpcAddresses.size // Make sure we can register with all masters at the same time + ) var coresUsed = 0 var memoryUsed = 0 diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 73cd9031f0250..4e044aa4788da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -25,8 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConverters._ -import com.google.common.util.concurrent.ThreadFactoryBuilder - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient @@ -40,7 +38,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -117,13 +115,9 @@ private[yarn] class YarnAllocator( // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) - private val launcherPool = new ThreadPoolExecutor( - // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue - sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, - 1, TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable](), - new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) - launcherPool.allowCoreThreadTimeOut(true) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( + "ContainerLauncher", + sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25)) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) From c3c92f8d1061618f8ac15dd37dd9898380341157 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 3 Dec 2015 11:09:29 -0800 Subject: [PATCH 1025/1168] [SPARK-12108] Make event logs smaller **Problem.** Event logs in 1.6 were much bigger than 1.5. I ran page rank and the event log size in 1.6 was almost 5x that in 1.5. I did a bisect to find that the RDD callsite added in #9398 is largely responsible for this. **Solution.** This patch removes the long form of the callsite (which is not used!) from the event log. This reduces the size of the event log significantly. *Note on compatibility*: if this patch is to be merged into 1.6.0, then it won't break any compatibility. Otherwise, if it is merged into 1.6.1, then we might need to add more backward compatibility handling logic (currently does not exist yet). Author: Andrew Or Closes #10115 from andrewor14/smaller-event-logs. --- .../org/apache/spark/storage/RDDInfo.scala | 4 +-- .../spark/ui/scope/RDDOperationGraph.scala | 4 +-- .../org/apache/spark/util/JsonProtocol.scala | 17 ++------- .../apache/spark/util/JsonProtocolSuite.scala | 35 ++++++++----------- 4 files changed, 20 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 87c1b981e7e13..94e8559bd2e91 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -28,7 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], - val callSite: CallSite = CallSite.empty, + val callSite: String = "", val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -58,6 +58,6 @@ private[spark] object RDDInfo { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 24274562657b3..e9c8a8e299cd7 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -39,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -178,7 +178,7 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + val label = s"${node.name} [${node.id}]\n${node.callsite}" s"""${node.id} [label="$label"]""" } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05af..2d2bd90eb339e 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,7 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ - ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ + ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -408,11 +408,6 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } - def callsiteToJson(callsite: CallSite): JValue = { - ("Short Form" -> callsite.shortForm) ~ - ("Long Form" -> callsite.longForm) - } - def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -857,9 +852,7 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite") - .map(callsiteFromJson) - .getOrElse(CallSite.empty) + val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -880,12 +873,6 @@ private[spark] object JsonProtocol { rddInfo } - def callsiteFromJson(json: JValue): CallSite = { - val shortForm = (json \ "Short Form").extract[String] - val longForm = (json \ "Long Form").extract[String] - CallSite(shortForm, longForm) - } - def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3f94ef7041914..1939ce5c743b0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,7 +111,6 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) - testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -343,13 +342,13 @@ class JsonProtocolSuite extends SparkFunSuite { // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 // "Callsite" was introduced in Spark 1.6.0 val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), - CallSite("short", "long"), Some(new RDDOperationScope("fable"))) + "callsite", Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) .removeField({ _._1 == "Callsite"}) val expectedRddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq.empty, CallSite.empty, scope = None) + 1, "one", 100, StorageLevel.NONE, Seq.empty, "", scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -397,11 +396,6 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } - private def testCallsite(callsite: CallSite): Unit = { - val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) - assert(callsite === newCallsite) - } - private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -726,8 +720,7 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, - Seq(1, 4, 7), CallSite(a.toString, b.toString)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7), a.toString) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -870,7 +863,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", - | "Callsite": {"Short Form": "101", "Long Form": "201"}, + | "Callsite": "101", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1273,7 +1266,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", - | "Callsite": {"Short Form": "1", "Long Form": "200"}, + | "Callsite": "1", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1317,7 +1310,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", - | "Callsite": {"Short Form": "2", "Long Form": "400"}, + | "Callsite": "2", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1335,7 +1328,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "401"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1379,7 +1372,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "600"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1397,7 +1390,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "601"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1415,7 +1408,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "602"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1459,7 +1452,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "800"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1477,7 +1470,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "801"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1495,7 +1488,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", - | "Callsite": {"Short Form": "6", "Long Form": "802"}, + | "Callsite": "6", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1513,7 +1506,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", - | "Callsite": {"Short Form": "7", "Long Form": "803"}, + | "Callsite": "7", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, From 6d6be99d87d568cd495b1d7f1c65d2aa92d8aef5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 3 Dec 2015 11:37:34 -0800 Subject: [PATCH 1026/1168] [MINOR][ML] Use coefficients replace weights Use ```coefficients``` replace ```weights```, I wish they are the last two. mengxr Author: Yanbo Liang Closes #10065 from yanboliang/coefficients. --- python/pyspark/ml/classification.py | 2 +- python/pyspark/ml/regression.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4a2982e2047ff..5599b8f3ecd88 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -49,7 +49,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) - >>> model.weights + >>> model.coefficients DenseVector([5.5...]) >>> model.intercept -2.68... diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 944e648ec8801..a0bb8ceed8861 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -40,7 +40,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) From 1c86f0f9546b05b861a0631c2d2f032e60668fc0 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Thu, 3 Dec 2015 11:59:10 -0800 Subject: [PATCH 1027/1168] [SPARK-12107][EC2] Update spark-ec2 versions I haven't created a JIRA. If we absolutely need one I'll do it, but I'm fine with not getting mentioned in the release notes if that's the only purpose it'll serve. cc marmbrus - We should include this in 1.6-RC2 if there is one. I can open a second PR against branch-1.6 if necessary. Author: Nicholas Chammas Closes #10109 from nchammas/spark-ec2-versions. --- ec2/spark_ec2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 84a950c9f6529..19d5980560fef 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.5.0" +SPARK_EC2_VERSION = "1.6.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -72,7 +72,10 @@ "1.3.1", "1.4.0", "1.4.1", - "1.5.0" + "1.5.0", + "1.5.1", + "1.5.2", + "1.6.0", ]) SPARK_TACHYON_MAP = { @@ -87,7 +90,10 @@ "1.3.1": "0.5.0", "1.4.0": "0.6.4", "1.4.1": "0.6.4", - "1.5.0": "0.7.1" + "1.5.0": "0.7.1", + "1.5.1": "0.7.1", + "1.5.2": "0.7.1", + "1.6.0": "0.8.2", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION From 64d6254cccb6543cbf7e64b0462719e288e1a2f0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 3 Dec 2015 12:00:09 -0800 Subject: [PATCH 1028/1168] [FLAKY-TEST-FIX][STREAMING][TEST] Make sure StreamingContexts are shutdown after test Author: Tathagata Das Closes #10124 from tdas/InputStreamSuite-flaky-test. --- .../spark/streaming/InputStreamsSuite.scala | 122 +++++++++--------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 047e38ef90998..3a3176b91b1ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -206,28 +206,28 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false - - // set up the network stream using the test receiver - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.receiverStream[Int](testReceiver) - val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - val outputStream = new TestOutputStream(countStream, outputBuffer) def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Let the data from the receiver be received - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val startTime = System.currentTimeMillis() - while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && - System.currentTimeMillis() - startTime < 5000) { - Thread.sleep(100) - clock.advance(batchDuration.milliseconds) + + // set up the network stream using the test receiver + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val networkStream = ssc.receiverStream[Int](testReceiver) + val countStream = networkStream.count + + val outputStream = new TestOutputStream(countStream, outputBuffer) + outputStream.register() + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while ((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -239,30 +239,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = true") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - val inputIterator = input.toIterator - for (i <- 0 until input.size) { - // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + val inputIterator = input.toIterator + for (i <- 0 until input.size) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -282,33 +282,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = false") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) - // Enqueue the first 3 items (one by one), they should be merged in the next batch - val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - - // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + } // Verify whether data received was as expected logInfo("--------------------------------") From e69bb9fba79a4bf31ff4665cf584edc15f18a64c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 13:25:20 -0800 Subject: [PATCH 1029/1168] [SPARK-12019][SPARKR] Support character vector for sparkR.init(), check param and fix doc and add tests. Spark submit expects comma-separated list Author: felixcheung Closes #10034 from felixcheung/sparkrinitdoc. --- R/pkg/R/client.R | 10 ++++-- R/pkg/R/sparkR.R | 56 ++++++++++++++++++++++----------- R/pkg/R/utils.R | 5 +++ R/pkg/inst/tests/test_client.R | 9 ++++++ R/pkg/inst/tests/test_context.R | 20 ++++++++++++ 5 files changed, 79 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index c811d1dac3bd5..25e99390a9c89 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -44,12 +44,16 @@ determineSparkSubmitBin <- function() { } generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + jars <- paste0(jars, collapse = ",") if (jars != "") { - jars <- paste("--jars", jars) + # construct the jars argument with a space between --jars and comma-separated values + jars <- paste0("--jars ", jars) } - if (!identical(packages, "")) { - packages <- paste("--packages", packages) + packages <- paste0(packages, collapse = ",") + if (packages != "") { + # construct the packages argument with a space between --packages and comma-separated values + packages <- paste0("--packages ", packages) } combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 7ff3fa628b9ca..d2bfad553104f 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -86,13 +86,13 @@ sparkR.stop <- function() { #' and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' -#' @param master The Spark master URL. +#' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes. -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. -#' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @param sparkEnvir Named list of environment variables to set on worker nodes +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -102,7 +102,9 @@ sparkR.stop <- function() { #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("jarfile1.jar","jarfile2.jar")) +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1", +#' "com.databricks:spark-csv_2.10:1.3.0")) #'} sparkR.init <- function( @@ -120,15 +122,8 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - jars <- suppressWarnings(normalizePath(as.character(sparkJars))) - - # Classpath separator is ";" on Windows - # URI needs four /// as from http://stackoverflow.com/a/18522792 - if (.Platform$OS.type == "unix") { - uriSep <- "//" - } else { - uriSep <- "////" - } + jars <- processSparkJars(sparkJars) + packages <- processSparkPackages(sparkPackages) sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) @@ -145,7 +140,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = submitOps, - packages = sparkPackages) + packages = packages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -195,8 +190,14 @@ sparkR.init <- function( paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- lapply(nonEmptyJars, + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + localJarPaths <- lapply(jars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -366,3 +367,22 @@ getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) { # --option must be before the application class "sparkr-shell" in submitOps paste0(paste0(envirToOps, collapse = ""), submitOps) } + +# Utility function that handles sparkJars argument, and normalize paths +processSparkJars <- function(jars) { + splittedJars <- splitString(jars) + if (length(splittedJars) > length(jars)) { + warning("sparkJars as a comma-separated string is deprecated, use character vector instead") + } + normalized <- suppressWarnings(normalizePath(splittedJars)) + normalized +} + +# Utility function that handles sparkPackages argument +processSparkPackages <- function(packages) { + splittedPackages <- splitString(packages) + if (length(splittedPackages) > length(packages)) { + warning("sparkPackages as a comma-separated string is deprecated, use character vector instead") + } + splittedPackages +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 45c77a86c9582..43105aaa38424 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -636,3 +636,8 @@ assignNewEnv <- function(data) { } env } + +# Utility function to split by ',' and whitespace, remove empty tokens +splitString <- function(input) { + Filter(nzchar, unlist(strsplit(input, ",|\\s"))) +} diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 8a20991f89af8..a0664f32f31c1 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -34,3 +34,12 @@ test_that("no package specified doesn't add packages flag", { test_that("multiple packages don't produce a warning", { expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) }) + +test_that("sparkJars sparkPackages as character vectors", { + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", + c("com.databricks:spark-avro_2.10:2.0.1", + "com.databricks:spark-csv_2.10:1.3.0")) + expect_match(args, "--jars one.jar,two.jar,three.jar") + expect_match(args, + "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index 80c1b89a4c627..1707e314beff5 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -92,3 +92,23 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli " --driver-memory 4g sparkr-shell2")) # nolint end }) + +test_that("sparkJars sparkPackages as comma-separated strings", { + expect_warning(processSparkJars(" a, b ")) + jars <- suppressWarnings(processSparkJars(" a, b ")) + expect_equal(jars, c("a", "b")) + + jars <- suppressWarnings(processSparkJars(" abc ,, def ")) + expect_equal(jars, c("abc", "def")) + + jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) + expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + + p <- processSparkPackages(c("ghi", "lmn")) + expect_equal(p, c("ghi", "lmn")) + + # check normalizePath + f <- dir()[[1]] + expect_that(processSparkJars(f), not(gives_warning())) + expect_match(processSparkJars(f), f) +}) From 2d3b87336783ca3c942657074af826db8864b435 Mon Sep 17 00:00:00 2001 From: Anderson de Andrade Date: Thu, 3 Dec 2015 16:37:00 -0800 Subject: [PATCH 1030/1168] [SPARK-12056][CORE] Create a TaskAttemptContext only after calling setConf. TaskAttemptContext's constructor will clone the configuration instead of referencing it. Calling setConf after creating TaskAttemptContext makes any changes to the configuration made inside setConf unperceived by RecordReader instances. As an example, Titan's InputFormat will change conf when calling setConf. They wrap their InputFormat around Cassandra's ColumnFamilyInputFormat, and append Cassandra's configuration. This change fixes the following error when using Titan's CassandraInputFormat with Spark: *java.lang.RuntimeException: org.apache.thrift.protocol.TProtocolException: Required field 'keyspace' was not present! Struct: set_key space_args(keyspace:null)* There's a discussion of this error here: https://groups.google.com/forum/#!topic/aureliusgraphs/4zpwyrYbGAE Author: Anderson de Andrade Closes #10046 from adeandrade/newhadooprdd-fix. --- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index d1960990da0fe..86f38ae836b2b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -138,14 +138,14 @@ class NewHadoopRDD[K, V]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) From 854bff14027694df53ff41143ad680781a7fd480 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 3 Dec 2015 16:39:12 -0800 Subject: [PATCH 1031/1168] [SPARK-11206] Support SQL UI on the history server (resubmit) Resubmit #9297 and #9991 On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an onOtherEvent method to the SparkListener trait and post all SQL related events to the same event bus. 2. Two SQL events SparkListenerSQLExecutionStart and SparkListenerSQLExecutionEnd are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait SparkHistoryListenerFactory is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using java.util.ServiceLoader. Author: Carson Wang Closes #10061 from carsonwang/SqlHistoryUI. --- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 19 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 44 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 329 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6a..7262c960ed6bb 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b88..23bc9a2e81727 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba6063..e6b24afd88ad4 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528cf..eaa07acc5132e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f1..075a7f13172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. + */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aad..95722a07144ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec8..8da6884a38535 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2d2bd90eb339e..cb0f1bf79f3d5 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -506,6 +511,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 0000000000000..507100be90967 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e26250868374..db286ea8700b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1245,6 +1245,7 @@ class SQLContext private[sql]( sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { SQLContext.clearInstantiatedContext() + SQLContext.clearSqlListener() } }) @@ -1272,6 +1273,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1316,6 +1319,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1364,9 +1371,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c94..34971986261c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 0000000000000..486ce34064e43 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 0000000000000..2708219ad3485 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e95..6c0f6f8a52dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1c..c74ad40406992 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6a..e19a1e3e5851f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc6..4f50b2ecdc8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3fe..7af0ff09c5c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82867ab4967bb..4f2cad19bfb6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096c..12a4e1356fed0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -338,6 +343,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { + SQLContext.clearSqlListener() val sqlContext = new SQLContext(sc) import sqlContext.implicits._ // Run 100 successful executions and 100 failed executions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62ed..e7b376548787c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From 99d89f7f2479272efd0fe474a5dc5c61f9959f1f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 3 Dec 2015 21:11:10 -0800 Subject: [PATCH 1032/1168] [SPARK-12104][SPARKR] collect() does not handle multiple columns with same name. Author: Sun Rui Closes #10118 from sun-rui/SPARK-12104. --- R/pkg/R/DataFrame.R | 8 ++++---- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a82ded9c51fac..81b4e6b91d8a2 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -822,21 +822,21 @@ setMethod("collect", # Get a column of complex type returns a list. # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] - colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[colName]] <- col + df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] # Note that "binary" columns behave like complex types. if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") - df[[colName]] <- vec + df[[colIndex]] <- vec } else { - df[[colName]] <- col + df[[colIndex]] <- col } } } + names(df) <- names(x) df } }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92ec82096c6df..1e7cb54099703 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,11 @@ test_that("collect() returns a data.frame", { expect_equal(names(rdf)[1], "age") expect_equal(nrow(rdf), 0) expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { @@ -1197,6 +1202,7 @@ test_that("join() and merge() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) From ff3cb64f437e9a3a769c1f26eaf5d7ba24aab243 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Dec 2015 01:42:29 -0800 Subject: [PATCH 1033/1168] [SPARK-12122][STREAMING] Prevent batches from being submitted twice after recovering StreamingContext from checkpoint Author: Tathagata Das Closes #10127 from tdas/SPARK-12122. --- .../org/apache/spark/streaming/scheduler/JobGenerator.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2de035d166e7b..8dfdc1f57b403 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -220,7 +220,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + val timesToReschedule = (pendingTimes ++ downTimes).filter { _ < restartTime } + .distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach { time => From a12d099131c383fddfd82b73ecbe14562a9f617f Mon Sep 17 00:00:00 2001 From: kaklakariada Date: Fri, 4 Dec 2015 14:43:16 +0000 Subject: [PATCH 1034/1168] Add links howto to setup IDEs for developing spark These links make it easier for new developers to work with Spark in their IDE. Author: kaklakariada Closes #10104 from kaklakariada/readme-developing-ide-gettting-started. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c0d6a946035a9..d5804d1a20b43 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). +For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) +and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). ## Interactive Scala Shell From 9a45fb6fc28185160a7dfb0a5e63a5473ef6d3b4 Mon Sep 17 00:00:00 2001 From: Nong Date: Fri, 4 Dec 2015 10:01:20 -0800 Subject: [PATCH 1035/1168] [SPARK-12089] [SQL] Fix memory corrupt due to freeing a page being referenced When the spillable sort iterator was spilled, it was mistakenly keeping the last page in memory rather than the current page. This causes the current record to get corrupted. Author: Nong Closes #10142 from nongli/spark-12089. --- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5a97f4f11340c..79d74b23ceaef 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -443,6 +443,7 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); while (inMemIterator.hasNext()) { @@ -458,9 +459,11 @@ public long spill() throws IOException { long released = 0L; synchronized (UnsafeExternalSorter.this) { - // release the pages except the one that is used + // release the pages except the one that is used. There can still be a caller that + // is accessing the current record. We free this page in that caller's next loadNext() + // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) { + if (!loaded || page.getBaseObject() != upstream.getBaseObject()) { released += page.size(); freePage(page); } else { From 5c5776d311258c3077a82272688b31b827ba5a94 Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Fri, 4 Dec 2015 12:03:45 -0800 Subject: [PATCH 1036/1168] [SPARK-6990][BUILD] Add Java linting script; fix minor warnings This replaces https://github.com/apache/spark/pull/9696 Invoke Checkstyle and print any errors to the console, failing the step. Use Google's style rules modified according to https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide Some important checks are disabled (see TODOs in `checkstyle.xml`) due to multiple violations being present in the codebase. Suggest fixing those TODOs in a separate PR(s). More on Checkstyle can be found on the [official website](http://checkstyle.sourceforge.net/). Sample output (from [build 46345](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46345/consoleFull)) (duplicated because I run the build twice with different profiles): > Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [error] running /home/jenkins/workspace/SparkPullRequestBuilder2/dev/lint-java ; received return code 1 Also fix some of the minor violations that didn't require sweeping changes. Apologies for the previous botched PRs - I finally figured out the issue. cr: JoshRosen, pwendell > I state that the contribution is my original work, and I license the work to the project under the project's open source license. Author: Dmitry Erastov Closes #9867 from dskrvk/master. --- checkstyle-suppressions.xml | 33 ++++ checkstyle.xml | 164 ++++++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 +- dev/lint-java | 30 ++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../spark/examples/mllib/JavaLDAExample.java | 3 +- ...ultiLabelClassificationMetricsExample.java | 12 +- ...ulticlassClassificationMetricsExample.java | 12 +- .../mllib/JavaRankingMetricsExample.java | 4 +- .../mllib/JavaRecommendationExample.java | 2 +- .../mllib/JavaRegressionMetricsExample.java | 3 +- .../streaming/JavaSqlNetworkWordCount.java | 4 +- .../ml/feature/JavaStringIndexerSuite.java | 6 +- .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../shuffle/ExternalShuffleBlockResolver.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- pom.xml | 24 +++ .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/types/SQLUserDefinedType.java | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../apache/spark/sql/hive/test/Complex.java | 86 ++++++--- .../spark/streaming/util/WriteAheadLog.java | 10 +- .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/JavaTrackStateByKeySuite.java | 2 +- .../apache/spark/tags/ExtendedHiveTest.java | 1 + .../apache/spark/tags/ExtendedYarnTest.java | 1 + .../apache/spark/unsafe/types/UTF8String.java | 8 +- 31 files changed, 368 insertions(+), 70 deletions(-) create mode 100644 checkstyle-suppressions.xml create mode 100644 checkstyle.xml create mode 100755 dev/lint-java diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 0000000000000..9242be3d0357a --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 0000000000000..a493ee443c752 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,164 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c91e88f31bf9b..c16cbce9a0f6c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -175,7 +175,7 @@ private SortedIterator(int numRecords) { this.position = 0; } - public SortedIterator clone () { + public SortedIterator clone() { SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d87a1d2a56d99..a5c583f9f2844 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -356,8 +356,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 0000000000000..fe8ab83d562d1 --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 4f390ef1eaa32..7aecea25b2099 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,6 +119,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb0697192..e7e10f1d8c725 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,6 +198,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -522,6 +527,8 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2f..0e8032d13341e 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,6 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, "BLOCK_TIMEOUT": 124 } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d479..ea83e8fef9eb9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc4974..de8e739ac9256 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java index b54e1ea3f2bcf..5ba01e0d08816 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -57,12 +57,12 @@ public static void main(String[] args) { // Stats by labels for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision - (metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure - (metrics.labels()[i])); + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision( + metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure( + metrics.labels()[i])); } // Micro stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 21f628fb51b6e..5247c9c748618 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -74,12 +74,12 @@ public Tuple2 call(LabeledPoint p) { // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision - (metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure - (metrics.labels()[i])); + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( + metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( + metrics.labels()[i])); } //Weighted stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 7c4c97e74681f..47ab3fc358246 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -120,8 +120,8 @@ public List call(Rating[] docs) { } } ); - JavaRDD, List>> relevantDocs = userMoviesList.join - (userRecommendedList).values(); + JavaRDD, List>> relevantDocs = userMoviesList.join( + userRecommendedList).values(); // Instantiate the metrics object RankingMetrics metrics = RankingMetrics.of(relevantDocs); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index 1065fde953b96..c179e7578cdfa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -29,7 +29,7 @@ // $example off$ public class JavaRecommendationExample { - public static void main(String args[]) { + public static void main(String[] args) { // $example on$ SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); JavaSparkContext jsc = new JavaSparkContext(conf); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index d2efc6bf97776..4e89dd0c37c52 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -43,8 +43,9 @@ public static void main(String[] args) { public LabeledPoint call(String line) { String[] parts = line.split(" "); double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) + for (int i = 1; i < parts.length - 1; i++) { v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + } return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb57..3515d7be45d37 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 +112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c342..b2df79ba74feb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -58,7 +58,7 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() @@ -67,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46c..225a216270b3b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac769..e5cb68c8a4dbb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -419,7 +419,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); public final int major; public final int minor; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 19c870aebb023..f573d962fe361 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -61,7 +61,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. - private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; diff --git a/pom.xml b/pom.xml index 234fd5dea1a6e..16e656d11961d 100644 --- a/pom.xml +++ b/pom.xml @@ -2256,6 +2256,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + true + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3986d6e18f770..352002b3499a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b36..1e4e5ede8cc11 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 2ed30c1f5a8d9..842dcb8c93dc2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -195,7 +195,7 @@ protected static final class NullIntIterator extends IntIterator { * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb9327..4ef1f276d1bbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if (!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if (!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c2..2803cad8095dd 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. */ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 609bb4413b6b1..9722c60bba1c3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1332,12 +1332,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index eac4cdd14a683..89d0bb7b617e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -95,7 +95,7 @@ public Double call(Optional one, State state) { JavaTrackStateDStream stateDstream2 = wordsDstream.trackStateByKey( - StateSpec. function(trackStateFunc2) + StateSpec.function(trackStateFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4e..83279e5e93c0e 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf0..108300168e173 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4bd3fd7772079..5b61386808769 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -900,9 +900,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -965,7 +965,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; From ea01d530725a80ed61b4613bc672849bb1a51bd4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 4 Dec 2015 12:08:42 -0800 Subject: [PATCH 1037/1168] [SPARK-12058][STREAMING][KINESIS][TESTS] fix Kinesis python tests Python tests require access to the `KinesisTestUtils` file. When this file exists under src/test, python can't access it, since it is not available in the assembly jar. However, if we move KinesisTestUtils to src/main, we need to add the KinesisProducerLibrary as a dependency. In order to avoid this, I moved KinesisTestUtils to src/main, and extended it with ExtendedKinesisTestUtils which is under src/test that adds support for the KPL. cc zsxwing tdas Author: Burak Yavuz Closes #10050 from brkyvz/kinesis-py. --- .../streaming/kinesis/KinesisTestUtils.scala | 88 +++++++++---------- .../kinesis/KPLBasedKinesisTestUtils.scala | 72 +++++++++++++++ .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 2 +- python/pyspark/streaming/tests.py | 1 - 5 files changed, 115 insertions(+), 50 deletions(-) rename extras/kinesis-asl/src/{test => main}/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala (82%) create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 82% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 7487aa1c12639..0ace453ee9280 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,13 +31,13 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ -import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} -import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! */ private[kinesis] class KinesisTestUtils extends Logging { @@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -66,14 +66,12 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } - private lazy val kinesisProducer: KinesisProducer = { - val conf = new KinesisProducerConfiguration() - .setRecordMaxBufferedTime(1000) - .setMaxConnections(1) - .setRegion(regionName) - .setMetricsLevel("none") - - new KinesisProducer(conf) + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } } def streamName: String = { @@ -104,41 +102,8 @@ private[kinesis] class KinesisTestUtils extends Logging { */ def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) - if (aggregate) { - val future = kinesisProducer.addUserRecord(streamName, str, data) - val kinesisCallBack = new FutureCallback[UserRecordResult]() { - override def onFailure(t: Throwable): Unit = {} // do nothing - - override def onSuccess(result: UserRecordResult): Unit = { - val shardId = result.getShardId - val seqNumber = result.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - - Futures.addCallback(future, kinesisCallBack) - kinesisProducer.flushSync() // make sure we send all data before returning the map - } else { - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(data) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -264,3 +229,32 @@ private[kinesis] object KinesisTestUtils { } } } + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 0000000000000..fdb270eaad8c9 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. */ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 52c61dfb1c023..d85b4cda8ce98 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index dee30444d8cc6..78cec021b78c1 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun sc = new SparkContext(conf) runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() } } diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d50c6b8d4a428..a2bfd79e1abcd 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1458,7 +1458,6 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") - @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From def672c5d0a5cce1448257648fd47d89a87f1084 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 4 Dec 2015 13:05:07 -0800 Subject: [PATCH 1038/1168] [SPARK-11314][BUILD][HOTFIX] Add exclusion for moved YARN classes. Author: Marcelo Vanzin Closes #10147 from vanzin/SPARK-11314. --- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d3a3c0ceb68c8..b4aa6adc3c620 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -159,7 +159,10 @@ object MimaExcludes { // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From dc1274598d00d0928c927f553fadb43b574814aa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 5 Dec 2015 08:15:30 +0800 Subject: [PATCH 1039/1168] [SPARK-12112][BUILD] Upgrade to SBT 0.13.9 We should upgrade to SBT 0.13.9, since this is a requirement in order to use SBT's new Maven-style resolution features (which will be done in a separate patch, because it's blocked by some binary compatibility issues in the POM reader plugin). I also upgraded Scalastyle to version 0.8.0, which was necessary in order to fix a Scala 2.10.5 compatibility issue (see https://github.com/scalastyle/scalastyle/issues/156). The newer Scalastyle is slightly stricter about whitespace surrounding tokens, so I fixed the new style violations. Author: Josh Rosen Closes #10112 from JoshRosen/upgrade-to-sbt-0.13.9. --- .../org/apache/spark/deploy/JsonProtocol.scala | 2 +- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 +++++----- .../spark/serializer/KryoSerializerSuite.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../streaming/clickstream/PageViewStream.scala | 6 ++++-- .../streaming/flume/sink/SparkSinkSuite.scala | 4 ++-- .../spark/graphx/lib/TriangleCountSuite.scala | 2 +- .../scala/org/apache/spark/ml/param/params.scala | 2 ++ .../mllib/stat/test/StreamingTestMethod.scala | 4 ++-- .../DecisionTreeClassifierSuite.scala | 4 ++-- .../ml/regression/DecisionTreeRegressorSuite.scala | 6 +++--- .../spark/mllib/tree/DecisionTreeSuite.scala | 14 +++++++------- pom.xml | 2 +- project/build.properties | 2 +- project/plugins.sbt | 7 +------ .../spark/sql/catalyst/expressions/CastSuite.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 8 ++++---- .../sql/execution/columnar/ColumnTypeSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 4 ++-- .../sql/execution/metric/SQLMetricsSuite.scala | 2 +- 20 files changed, 47 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index ccffb36652988..220b20bf7cbd1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -45,7 +45,7 @@ private[deploy] object JsonProtocol { ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ + ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 46ed5c04f4338..007a71f87cf10 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -101,21 +101,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[UnionRDD[_]]) } test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) } test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) intercept[IllegalArgumentException] { new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e428414cf6e85..f81fe3113106f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -144,10 +144,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("Bug: SPARK-10251") { @@ -174,10 +174,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("ranges") { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index bea7a47cb2855..2fcccb22dddf7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -51,8 +51,8 @@ object PageView extends Serializable { */ // scalastyle:on object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, "http://foo.com/contact" -> .1) val httpStatus = Map(200 -> .95, 404 -> .05) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 4ef238606f82e..723616817f6a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -86,8 +86,10 @@ object PageViewStream { .map("Unique active users: " + _) // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2 -> "Reynold Xin", 3 -> "Matei Zaharia").toSeq) + val userList = ssc.sparkContext.parallelize(Seq( + 1 -> "Patrick Wendell", + 2 -> "Reynold Xin", + 3 -> "Matei Zaharia")) metric match { case "pageCounts" => pageCounts.print() diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index d2654700ea729..941fde45cd7b7 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -36,11 +36,11 @@ import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory // Spark core main, which has too many dependencies to require here manually. // For this reason, we continue to use FunSuite and ignore the scalastyle checks // that fail if this is detected. -//scalastyle:off +// scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { -//scalastyle:on +// scalastyle:on val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index c47552cf3a3bd..608e43cf3ff53 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -26,7 +26,7 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => - val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) + val rawEdges = sc.parallelize(Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ), 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d182b0a98896c..ee7e89edd8798 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -82,7 +82,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali def w(value: T): ParamPair[T] = this -> value /** Creates a param pair with the given value (for Scala). */ + // scalastyle:off def ->(value: T): ParamPair[T] = ParamPair(this, value) + // scalastyle:on /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ def jsonEncode(value: T): String = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index a7eaed51b4d55..911b4b9237356 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -152,8 +152,8 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { private[stat] object StreamingTestMethod { // Note: after new `StreamingTestMethod`s are implemented, please update this map. private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( - "welch"->WelchTTest, - "student"->StudentTTest) + "welch" -> WelchTTest, + "student" -> StudentTTest) def getTestMethodFromName(method: String): StreamingTestMethod = TEST_NAME_TO_OBJECT.get(method) match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 92b8f84144ab0..fda2711fed0fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -73,7 +73,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } @@ -214,7 +214,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxBins(2) .setMaxDepth(2) .setMinInstancesPerNode(2) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val numClasses = 2 compareAPIs(rdd, dt, categoricalFeatures, numClasses) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index e0d5afa7a7e97..6999a910c34a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -50,7 +50,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } @@ -59,12 +59,12 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } test("copied model must have the same parent") { - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val model = new DecisionTreeRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1a4299db4eab2..bf8fe1acac2fe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -178,7 +178,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) @@ -237,7 +237,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -421,7 +421,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -455,7 +455,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -484,7 +484,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -788,7 +788,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, - maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) val rootNode = DecisionTree.train(rdd, strategy).topNode diff --git a/pom.xml b/pom.xml index 16e656d11961d..ae2ff8878b0a5 100644 --- a/pom.xml +++ b/pom.xml @@ -2235,7 +2235,7 @@ org.scalastyle scalastyle-maven-plugin - 0.7.0 + 0.8.0 false true diff --git a/project/build.properties b/project/build.properties index 064ec843da9ea..86ca8755820a4 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.7 +sbt.version=0.13.9 diff --git a/project/plugins.sbt b/project/plugins.sbt index c06687d8f197b..5e23224cf8aa5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,14 +10,9 @@ addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") -// For Sonatype publishing -//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) - -//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ab77a764483e8..a98e16c253214 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -734,7 +734,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val complex = Literal.create( Row( Seq("123", "true", "f"), - Map("a" ->"123", "b" -> "true", "c" -> "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index fdd745f48e973..bb3e2786978c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -862,9 +862,9 @@ private[sql] object ParquetRelation extends Logging { // The parquet compression short names val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, + "NONE" -> CompressionCodecName.UNCOMPRESSED, "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 34dd96929e6c1..706ff1f998501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -35,7 +35,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index ce3affba55c71..95642e93ae9f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -100,11 +100,11 @@ class RunLengthEncodingSuite extends SparkFunSuite { } test(s"$RunLengthEncoding with $typeName: simple case") { - skeleton(2, Seq(0 -> 2, 1 ->2)) + skeleton(2, Seq(0 -> 2, 1 -> 2)) } test(s"$RunLengthEncoding with $typeName: run length == 1") { - skeleton(2, Seq(0 -> 1, 1 ->1)) + skeleton(2, Seq(0 -> 1, 1 -> 1)) } test(s"$RunLengthEncoding with $typeName: single long run") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6b..4339f7260dcb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -116,7 +116,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) val df = person.select('name) testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( + 0L -> ("Project", Map( "number of rows" -> 2L))) ) } From 5c49b17849365ab67844735151cd49112e5e5617 Mon Sep 17 00:00:00 2001 From: meiyoula <1039320815@qq.com> Date: Fri, 4 Dec 2015 16:50:40 -0800 Subject: [PATCH 1040/1168] [SPARK-12142][CORE]Reply false when container allocator is not ready and reset target Using Dynamic Allocation function, when a new AM is starting, and ExecutorAllocationManager send RequestExecutor message to AM. If the container allocator is not ready, the whole app will hang on Author: meiyoula <1039320815@qq.com> Closes #10138 from XuTingjun/patch-1. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 1 + .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6419218f47c85..34c32ce312964 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -370,6 +370,7 @@ private[spark] class ExecutorAllocationManager( } else { logWarning( s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") + numExecutorsTarget = oldNumExecutorsTarget 0 } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 13ef4dfd64165..1970f7d150feb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -600,11 +600,12 @@ private[spark] class ApplicationMaster( localityAwareTasks, hostToLocalTaskCount)) { resetAllocatorInterval() } + context.reply(true) case None => logWarning("Container allocator is not ready to request executors yet.") + context.reply(false) } - context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") From 0a7f72268ceaded245473e8e0e373067c3aedf1d Mon Sep 17 00:00:00 2001 From: rotems Date: Fri, 4 Dec 2015 16:58:31 -0800 Subject: [PATCH 1041/1168] [SPARK-12080][CORE] Kryo - Support multiple user registrators Author: rotems Closes #10078 from Botnaim/KryoMultipleCustomRegistrators. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 6 ++++-- docs/configuration.md | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d5ba690ed04be..7b77f78ce6f1a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -70,7 +70,9 @@ class KryoSerializer(conf: SparkConf) private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val userRegistrators = conf.get("spark.kryo.registrator", "") + .split(',') + .filter(!_.isEmpty) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) @@ -119,7 +121,7 @@ class KryoSerializer(conf: SparkConf) classesToRegister .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. - userRegistrator + userRegistrators .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname diff --git a/docs/configuration.md b/docs/configuration.md index c39b4890851bc..fd61ddc244f44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -647,10 +647,10 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This + If you use Kryo serialization, give a comma-separated list of classes that register your custom classes with Kryo. This property is useful if you need to register your classes in a custom way, e.g. to specify a custom field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends + set to classes that extend KryoRegistrator. See the tuning guide for more details. From 40d8a512155d74c6a2a70f6e85deec80ba59a197 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 4 Dec 2015 17:02:04 -0800 Subject: [PATCH 1042/1168] [SPARK-12084][CORE] Fix codes that uses ByteBuffer.array incorrectly `ByteBuffer` doesn't guarantee all contents in `ByteBuffer.array` are valid. E.g, a ByteBuffer returned by `ByteBuffer.slice`. We should not use the whole content of `ByteBuffer` unless we know that's correct. This patch fixed all places that use `ByteBuffer.array` incorrectly. Author: Shixiong Zhu Closes #10083 from zsxwing/bytebuffer-array. --- .../network/netty/NettyBlockTransferService.scala | 12 +++--------- .../org/apache/spark/scheduler/DAGScheduler.scala | 6 ++++-- .../scala/org/apache/spark/scheduler/Task.scala | 4 ++-- .../spark/serializer/GenericAvroSerializer.scala | 5 ++++- .../apache/spark/serializer/KryoSerializer.scala | 4 ++-- .../spark/storage/TachyonBlockManager.scala | 2 +- .../main/scala/org/apache/spark/util/Utils.scala | 15 ++++++++++++++- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 5 +++-- .../apache/spark/scheduler/TaskContextSuite.scala | 3 ++- .../pythonconverters/AvroConverters.scala | 5 ++++- .../spark/streaming/flume/FlumeInputDStream.scala | 6 +++--- .../streaming/flume/FlumePollingStreamSuite.scala | 4 ++-- .../spark/streaming/flume/FlumeStreamSuite.scala | 4 ++-- .../streaming/kinesis/KinesisStreamSuite.scala | 3 ++- .../parquet/UnsafeRowParquetRecordReader.java | 7 ++++--- .../spark/sql/execution/SparkSqlSerializer.scala | 3 ++- .../columnar/InMemoryColumnarTableScan.scala | 5 ++++- .../parquet/CatalystRowConverter.scala | 8 ++++---- .../scheduler/ReceivedBlockTracker.scala | 5 +++-- .../streaming/util/BatchedWriteAheadLog.scala | 15 ++++++--------- .../util/FileBasedWriteAheadLogWriter.scala | 14 +++----------- .../spark/streaming/JavaWriteAheadLogSuite.java | 15 +++++++-------- 22 files changed, 81 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 82c16e855b0c0..40604a4da18d5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level)) // Convert or copy nio buffer into array in order to serialize it. - val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e01a9609b9a0d..5582720bbcff2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout @@ -997,9 +998,10 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2fcd5aa57d11b..5fe5ae8c45819 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -191,8 +191,8 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) ByteBuffer.wrap(out.toByteArray) } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f2126..8d6af9cae8927 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. */ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) new Schema.Parser().parse(new String(bytes, "UTF-8")) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7b77f78ce6f1a..62d445f3d7bd9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 22878783fca67..d14fe4613528a 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + Utils.writeByteBuffer(bytes, os) } catch { case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index af632349c9cae..9dbe66e7eefbd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -178,7 +178,20 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index a5c583f9f2844..8724a34988421 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -41,6 +41,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; @@ -430,7 +431,7 @@ public void randomizedStressTest() { } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); @@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 450ab7b9fe92b..d83d0aee42254 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import org.apache.spark.metrics.source.JvmSource @@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 805184e740f06..cf12c98b4af6c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable { def unpackBytes(obj: Any): Array[Byte] = { val bytes: Array[Byte] = obj match { - case buf: java.nio.ByteBuffer => buf.array() + case buf: java.nio.ByteBuffer => + val arr = new Array[Byte](buf.remaining()) + buf.get(arr) + arr case arr: Array[Byte] => arr case other => throw new SparkException( s"Unknown BYTES type ${other.getClass.getName}") diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index c8780aa83bdbd..2b9116eb3c790 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable { /* Serialize to bytes. */ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) + val body = event.getBody + out.writeInt(body.remaining()) + Utils.writeByteBuffer(body, out) val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 5fd2711f5f7df..bb951a6ef100d 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} @@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { case (key, value) => (key.toString, value.toString) }).map(_.asJava) - val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody)) utils.assertOutput(headers.asJava, bodies.asJava) } } finally { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f315e0a7ca23c..b29e591c07374 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} @@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) output should be (input) } } finally { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 78cec021b78c1..6fe24fe81165b 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ @@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun testIfEnabled("custom message handling") { val awsCredentials = KinesisTestUtils.getAWSCredentials() - def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5 + def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, addFive, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index dade488ca281b..0cc4566c9cdde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -332,12 +332,13 @@ private void decodeBinaryBatch(int col, int num) throws IOException { for (int n = 0; n < num; ++n) { if (columnReaders[col].next()) { ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); - int len = bytes.limit() - bytes.position(); + int len = bytes.remaining(); if (originalTypes[col] == OriginalType.UTF8) { - UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + UTF8String str = + UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len); rowWriters[n].write(col, str); } else { - rowWriters[n].write(col, bytes.array(), bytes.position(), len); + rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len); } rows[n].setNotNullAt(col); } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 8317f648ccb4e..45a8e03248267 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair @@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer { def serialize[T: ClassTag](o: T): Array[Byte] = acquireRelease { k => - k.serialize(o).array() + JavaUtils.bufferToArray(k.serialize(o)) } def deserialize[T: ClassTag](bytes: Array[Byte]): T = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ce701fb3a7f28..3c5a8cb2aa935 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation( .flatMap(_.values)) batchStats += stats - CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) } def hasNext: Boolean = rowIterator.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 94298fae2d69b..8851bc23cd050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter( // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying // it. val buffer = value.toByteBuffer - val offset = buffer.position() - val numBytes = buffer.limit() - buffer.position() + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) } } @@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter { // copying it. val buffer = binary.toByteBuffer val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() + val start = buffer.arrayOffset() + buffer.position() + val end = buffer.arrayOffset() + buffer.limit() var unscaled = 0L var i = start diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 500dc70c98506..4dab64d696b3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} @@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") writeAheadLog.readAll().asScala.foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) + logInfo("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( - byteBuffer.array, Thread.currentThread().getContextClassLoader) match { + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 6e6ed8d819721..7158abc08894a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils /** @@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog { */ case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) - /** Copies the byte array of a ByteBuffer. */ - private def getByteArray(buffer: ByteBuffer): Array[Byte] = { - val byteArray = new Array[Byte](buffer.remaining()) - buffer.get(byteArray) - byteArray - } - /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ def aggregate(records: Seq[Record]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( - records.map(record => getByteArray(record.data)).toArray)) + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) } /** @@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog { * method therefore needs to be backwards compatible. */ def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() try { - Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) } catch { case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) Array(buffer) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a456..1185f30265f63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -24,6 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.spark.util.Utils + /** * A writer for writing byte-buffers to a write ahead log file. */ @@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 09b5f8ed03279..f02fa87f6194b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; @@ -27,6 +26,7 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -112,20 +112,19 @@ public void testCustomWAL() { WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = - wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); Iterator dataIterator = wal.readAll(); List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } From 29c481b2d5e90fb63e596b2acd8d612a3d6572ca Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 5 Dec 2015 15:27:31 +0000 Subject: [PATCH 1043/1168] [SPARK-12096][MLLIB] remove the old constraint in word2vec jira: https://issues.apache.org/jira/browse/SPARK-12096 word2vec now can handle much bigger vocabulary. The old constraint vocabSize.toLong * vectorSize < Ine.max / 8 should be removed. new constraint is vocabSize.toLong * vectorSize < max array length (usually a little less than Int.MaxValue) I tested with vocabsize over 18M and vectorsize = 100. srowen jkbradley Sorry to miss this in last PR. I was reminded today. Author: Yuhao Yang Closes #10103 from hhbyyh/w2vCapacity. --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 655ac0bb5545b..be12d45286034 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -306,10 +306,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + if (vocabSize.toLong * vectorSize >= Int.MaxValue) { throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + - "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") } val syn0Global = From f3aa57aadbbc38c6611977ee8034134d27841370 Mon Sep 17 00:00:00 2001 From: Antonio Murgia Date: Sat, 5 Dec 2015 15:42:02 +0000 Subject: [PATCH 1044/1168] [SPARK-11994][MLLIB] Word2VecModel load and save cause SparkException when model is bigger than spark.kryoserializer.buffer.max Author: Antonio Murgia Closes #9989 from tmnd1991/SPARK-11932. --- .../apache/spark/mllib/feature/Word2Vec.scala | 16 ++++++++++++---- .../spark/mllib/feature/Word2VecSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index be12d45286034..b693f3c8e4bd9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -604,13 +604,21 @@ object Word2VecModel extends Loader[Word2VecModel] { val vectorSize = model.values.head.size val numWords = model.size - val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ - ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + val metadata = compact(render( + ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + // We want to partition the model in partitions of size 32MB + val partitionSize = (1L << 25) + // We calculate the approximate size of the model + // We only calculate the array size, not considering + // the string size, the formula is: + // floatSize * numWords * vectorSize + val approxSize = 4L * numWords * vectorSize + val nPartitions = ((approxSize / partitionSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index a864eec460f2b..37d01e2876695 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("big model load / save") { + // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 + val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + } From 9982fd25ed694b1cef991eeda86aec037a4668d3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 5 Dec 2015 15:52:52 +0000 Subject: [PATCH 1045/1168] [SPARK-11988][ML][MLLIB] Update JPMML to 1.2.7 Update JPMML pmml-model to 1.2.7 Author: Sean Owen Closes #9972 from srowen/SPARK-11988. --- LICENSE | 3 +- mllib/pom.xml | 2 +- .../BinaryClassificationPMMLModelExport.scala | 32 +++++++------- .../GeneralizedLinearPMMLModelExport.scala | 26 +++++------ .../pmml/export/KMeansPMMLModelExport.scala | 44 +++++++++---------- .../mllib/pmml/export/PMMLModelExport.scala | 17 +++---- 6 files changed, 59 insertions(+), 65 deletions(-) diff --git a/LICENSE b/LICENSE index 0db2d14465bd3..a2f75b817ab37 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -237,7 +236,7 @@ The following components are provided under a BSD-style license. See project lin The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/mllib/pom.xml b/mllib/pom.xml index 70139121d8c78..df50aca1a3f76 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -109,7 +109,7 @@ org.jpmml pmml-model - 1.1.15 + 1.2.7 com.sun.xml.fastinfoset diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 622b53a252ac5..7abb1bf7ce967 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -45,7 +45,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( val fields = new SArray[FieldName](model.weights.size) val dataDictionary = new DataDictionary val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { if (threshold <= 0) { @@ -56,35 +56,35 @@ private[mllib] class BinaryClassificationPMMLModelExport( interceptNO = -math.log(1 / threshold - 1) } } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0") val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) + .setFunctionName(MiningFunctionType.CLASSIFICATION) + .setMiningSchema(miningSchema) + .setModelName(description) + .setNormalizationMethod(normalizationMethod) + .addRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field val targetField = FieldName.create("target") dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + .addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 1874786af0002..4d951d2973a6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -45,31 +45,31 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.REGRESSION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withRegressionTables(regressionTable) + .setFunctionName(MiningFunctionType.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(description) + .addRegressionTables(regressionTable) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 069e7afc9fca0..b5b824bb9c9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -42,42 +42,42 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val comparisonMeasure = new ComparisonMeasure() - .withKind(ComparisonMeasure.Kind.DISTANCE) - .withMeasure(new SquaredEuclidean()) + .setKind(ComparisonMeasure.Kind.DISTANCE) + .setMeasure(new SquaredEuclidean()) val clusteringModel = new ClusteringModel() - .withModelName("k-means") - .withMiningSchema(miningSchema) - .withComparisonMeasure(comparisonMeasure) - .withFunctionName(MiningFunctionType.CLUSTERING) - .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) - .withNumberOfClusters(model.clusterCenters.length) + .setModelName("k-means") + .setMiningSchema(miningSchema) + .setComparisonMeasure(comparisonMeasure) + .setFunctionName(MiningFunctionType.CLUSTERING) + .setModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .setNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + clusteringModel.addClusteringFields( + new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF)) } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) - for (i <- 0 until model.clusterCenters.length) { + for (i <- model.clusterCenters.indices) { val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Array.Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) + .setName("cluster_" + i) + .setArray(new org.dmg.pmml.Array() + .setType(Array.Type.REAL) + .setN(clusterCenter.size) + .setValue(model.clusterCenters(i).toArray.mkString(" "))) // we don't have the size of the single cluster but only the centroids (withValue) // .withSize(value) - clusteringModel.withClusters(cluster) + clusteringModel.addClusters(cluster) } pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) + pmml.addModels(clusteringModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 9267e6dbdb857..426bb818c9266 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -30,19 +30,14 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML - - pmml.setVersion("4.2") - setHeader(pmml) - - private def setHeader(pmml: PMML): Unit = { + val pmml: PMML = { val version = getClass.getPackage.getImplementationVersion - val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) - pmml.setHeader(header) + .setApplication(app) + .setTimestamp(timestamp) + new PMML("4.2", header, null) } } From b33e47bfefd627513f8844308d0350c73234d819 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sat, 5 Dec 2015 15:49:51 -0800 Subject: [PATCH 1046/1168] [SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR. Author: Sun Rui Closes #9804 from sun-rui/SPARK-11774. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/functions.R | 59 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 12 +++++++ R/pkg/inst/tests/test_sparkSQL.R | 37 ++++++++++++++++---- 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 43e5e0119e7fe..565a2b1a68b5f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -134,8 +134,10 @@ exportMethods("%in%", "datediff", "dayofmonth", "dayofyear", + "decode", "dense_rank", "desc", + "encode", "endsWith", "exp", "explode", @@ -225,6 +227,7 @@ exportMethods("%in%", "stddev", "stddev_pop", "stddev_samp", + "struct", "sqrt", "startsWith", "substr", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b30331c61c9a7..7432cb8e7ccf6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -357,6 +357,40 @@ setMethod("dayofyear", column(jc) }) +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + #' exp #' #' Computes the exponential of the given value. @@ -1039,6 +1073,31 @@ setMethod("stddev_samp", column(jc) }) +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 711ce38f9e104..4b5f786d39461 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + #' @rdname dense_rank #' @export setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) + #' @rdname explode #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) @@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @export setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) { standardGeneric("struct") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1e7cb54099703..2d26b92ac7275 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -27,6 +27,11 @@ checkStructField <- function(actual, expectedName, expectedType, expectedNullabl expect_equal(actual$nullable(), expectedNullable) } +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -551,11 +556,6 @@ test_that("collect() and take() on a DataFrame return the same number of rows an }) test_that("collect() support Unicode characters", { - markUtf8 <- function(s) { - Encoding(s) <- "UTF-8" - s - } - lines <- c("{\"name\":\"안녕하세요\"}", "{\"name\":\"您好\", \"age\":30}", "{\"name\":\"こんにちは\", \"age\":19}", @@ -933,8 +933,33 @@ test_that("column functions", { # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大千世界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大千世界")) }) -# + test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", From 4d3ad9458851f11366997865d69cd1bed85a4e66 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 5 Dec 2015 16:00:12 -0800 Subject: [PATCH 1047/1168] [SPARK-11715][SPARKR] Add R support corr for Column Aggregration Need to match existing method signature Author: felixcheung Closes #9680 from felixcheung/rcorr. --- R/pkg/R/functions.R | 15 +++++++++++++++ R/pkg/R/generics.R | 2 +- R/pkg/R/stats.R | 9 +++++---- R/pkg/inst/tests/test_sparkSQL.R | 2 +- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7432cb8e7ccf6..25231451df3d2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -259,6 +259,21 @@ setMethod("column", function(x) { col(x) }) +#' corr +#' +#' Computes the Pearson Correlation Coefficient for two Columns. +#' +#' @rdname corr +#' @name corr +#' @family math_funcs +#' @export +#' @examples \dontrun{corr(df$c, df$d)} +setMethod("corr", signature(x = "Column"), + function(x, col2) { + stopifnot(class(col2) == "Column") + jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc) + column(jc) + }) #' cos #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4b5f786d39461..acfd4841e19af 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -411,7 +411,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @rdname statfunctions #' @export -setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) +setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname summary #' @export diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index f79329b115404..d17cce9c756e2 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -77,7 +77,7 @@ setMethod("cov", #' Calculates the correlation of two columns of a DataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. -#' +#' #' @param x A SparkSQL DataFrame #' @param col1 the name of the first column #' @param col2 the name of the second column @@ -95,8 +95,9 @@ setMethod("cov", #' corr <- corr(df, "title", "gender", method = "pearson") #' } setMethod("corr", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2, method = "pearson") { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "corr", col1, col2, method) }) @@ -109,7 +110,7 @@ setMethod("corr", #' #' @param x A SparkSQL DataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. #' Should be greater than 1e-4. Default support = 0.01. #' @return a local R data.frame with the frequent items in each column #' @@ -131,7 +132,7 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. -#' +#' #' @param x A SparkSQL DataFrame #' @param col column that defines strata #' @param fractions A named list giving sampling fraction for each stratum. If a stratum is diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 2d26b92ac7275..a5a234a02d9f2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -892,7 +892,7 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cume_dist() + ntile(1) + c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed From 83f77bf2912a34224571d7d0cd6f4242675018f0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 16:39:01 -0800 Subject: [PATCH 1048/1168] [SPARK-12115][SPARKR] Change numPartitions() to getNumPartitions() to be consistent with Scala/Python Change ```numPartitions()``` to ```getNumPartitions()``` to be consistent with Scala/Python. Note: If we can not catch up with 1.6 release, it will be breaking change for 1.7 that we also need to explain in release note. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10123 from yanboliang/spark-12115. --- R/pkg/R/RDD.R | 55 ++++++++++++++++++++++--------------- R/pkg/R/generics.R | 6 +++- R/pkg/R/pairRDD.R | 4 +-- R/pkg/inst/tests/test_rdd.R | 10 +++---- 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 47945c2825da9..00c40c38cabc9 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -306,17 +306,28 @@ setMethod("checkpoint", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' numPartitions(rdd) # 2L +#' getNumPartitions(rdd) # 2L #'} -#' @rdname numPartitions +#' @rdname getNumPartitions +#' @aliases getNumPartitions,RDD-method +#' @noRd +setMethod("getNumPartitions", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "getNumPartitions") + }) + +#' Gets the number of partitions of an RDD, the same as getNumPartitions. +#' But this function has been deprecated, please use getNumPartitions. +#' +#' @rdname getNumPartitions #' @aliases numPartitions,RDD-method #' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { - jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "partitions") - callJMethod(partitions, "size") + .Deprecated("getNumPartitions") + getNumPartitions(x) }) #' Collect elements of an RDD @@ -443,7 +454,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, numPartitions(x))) + collect(reduceByKey(ones, `+`, getNumPartitions(x))) }) #' Apply a function to all elements @@ -759,7 +770,7 @@ setMethod("take", resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- numPartitions(x) + numPartitions <- getNumPartitions(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -823,7 +834,7 @@ setMethod("first", #' @noRd setMethod("distinct", signature(x = "RDD"), - function(x, numPartitions = SparkR:::numPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -993,8 +1004,8 @@ setMethod("keyBy", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -#' numPartitions(rdd) # 4 -#' numPartitions(repartition(rdd, 2L)) # 2 +#' getNumPartitions(rdd) # 4 +#' getNumPartitions(repartition(rdd, 2L)) # 2 #'} #' @rdname repartition #' @aliases repartition,RDD @@ -1014,8 +1025,8 @@ setMethod("repartition", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -#' numPartitions(rdd) # 3 -#' numPartitions(coalesce(rdd, 1L)) # 1 +#' getNumPartitions(rdd) # 3 +#' getNumPartitions(coalesce(rdd, 1L)) # 1 #'} #' @rdname coalesce #' @aliases coalesce,RDD @@ -1024,7 +1035,7 @@ setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1112,7 +1123,7 @@ setMethod("saveAsTextFile", #' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1144,7 +1155,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- numPartitions(newRdd) + numPartitions <- getNumPartitions(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1368,7 +1379,7 @@ setMethod("setName", setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1409,7 +1420,7 @@ setMethod("zipWithUniqueId", setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) if (n > 1) { nums <- collect(lapplyPartition(x, function(part) { @@ -1521,8 +1532,8 @@ setMethod("unionRDD", setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- numPartitions(x) - n2 <- numPartitions(other) + n1 <- getNumPartitions(x) + n2 <- getNumPartitions(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1588,7 +1599,7 @@ setMethod("cartesian", #' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) @@ -1620,7 +1631,7 @@ setMethod("subtract", #' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1661,7 +1672,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, numPartitions) + nPart <- sapply(rrdds, getNumPartitions) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index acfd4841e19af..29dd11f41ff5e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -133,7 +133,11 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname numPartitions +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + +# @rdname getNumPartitions # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 991bea4d2022d..334c11d2f89a1 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -750,7 +750,7 @@ setMethod("cogroup", #' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -818,7 +818,7 @@ setMethod("sortByKey", #' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 71aed2bb9d6a8..7423b4f2bed1f 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -28,8 +28,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(numPartitions(rdd), 2) - expect_equal(numPartitions(intRdd), 2) + expect_equal(getNumPartitions(rdd), 2) + expect_equal(getNumPartitions(intRdd), 2) }) test_that("first on RDD", { @@ -304,18 +304,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartition(rdd, 2) - expect_equal(numPartitions(r1), 2L) + expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartition(rdd, 6) - expect_equal(numPartitions(r2), 6L) + expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) - expect_equal(numPartitions(r3), 1L) + expect_equal(getNumPartitions(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) From 4102584a71f007fd8c70c99d83865007ef8c979c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 22:51:05 -0800 Subject: [PATCH 1049/1168] [SPARK-12044][SPARKR] Fix usage of isnan, isNaN 1, Add ```isNaN``` to ```Column``` for SparkR. ```Column``` should has three related variable functions: ```isNaN, isNull, isNotNull```. 2, Replace ```DataFrame.isNaN``` with ```DataFrame.isnan``` at SparkR side. Because ```DataFrame.isNaN``` has been deprecated and will be removed at Spark 2.0. 3, Add ```isnull``` to ```DataFrame``` for SparkR. ```DataFrame``` should has two related functions: ```isnan, isnull```. cc shivaram sun-rui felixcheung Author: Yanbo Liang Closes #10037 from yanboliang/spark-12044. --- R/pkg/R/column.R | 2 +- R/pkg/R/functions.R | 26 +++++++++++++++++++------- R/pkg/R/generics.R | 8 ++++++-- R/pkg/inst/tests/test_sparkSQL.R | 6 +++++- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 20de3907b7dd9..7bb8ef2595b59 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -56,7 +56,7 @@ operators <- list( "&" = "and", "|" = "or", #, "!" = "unary_$bang" "^" = "pow" ) -column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") createOperator <- function(op) { diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25231451df3d2..09e4e04335a33 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -537,19 +537,31 @@ setMethod("initcap", column(jc) }) -#' isNaN +#' is.nan #' -#' Return true iff the column is NaN. +#' Return true if the column is NaN, alias for \link{isnan} #' -#' @rdname isNaN -#' @name isNaN +#' @rdname is.nan +#' @name is.nan #' @family normal_funcs #' @export -#' @examples \dontrun{isNaN(df$c)} -setMethod("isNaN", +#' @examples +#' \dontrun{ +#' is.nan(df$c) +#' isnan(df$c) +#' } +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @rdname is.nan +#' @name isnan +setMethod("isnan", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 29dd11f41ff5e..c383e6e78b8b4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -625,6 +625,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -808,9 +812,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname isNaN +#' @rdname is.nan #' @export -setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname kurtosis #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index a5a234a02d9f2..6ef03ae97635e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -883,7 +883,7 @@ test_that("column functions", { c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) - c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) @@ -894,6 +894,10 @@ test_that("column functions", { c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() + c16 <- is.nan(c) + isnan(c) + isNaN(c) + + # Test if base::is.nan() is exposed + expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From 680ab413f757a183ba3b49676c326434148777dc Mon Sep 17 00:00:00 2001 From: gcc Date: Sun, 6 Dec 2015 16:27:40 +0000 Subject: [PATCH 1050/1168] [SPARK-12048][SQL] Prevent to close JDBC resources twice Author: gcc Closes #10101 from rh99/master. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b9dd7f6b4099b..1c348ed62fc78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -511,6 +511,7 @@ private[sql] class JDBCRDD( } catch { case e: Exception => logWarning("Exception closing connection", e) } + closed = true } override def hasNext: Boolean = { From 55c2ae6e1e3b2ebeed1c96730c93ef08202a472d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 Dec 2015 11:15:02 -0800 Subject: [PATCH 1051/1168] [SPARK-12138][SQL] Escape \u in the generated comments of codegen When \u appears in a comment block (i.e. in /**/), code gen will break. So, in Expression and CodegenFallback, we escape \u to \\u. yhuai Please review it. I did reproduce it and it works after the fix. Thanks! Author: gatorsmile Closes #10155 from gatorsmile/escapeU. --- .../spark/sql/catalyst/expressions/Expression.scala | 4 +++- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 614f0c075fd23..6d807c9ecf302 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -220,7 +220,9 @@ abstract class Expression extends TreeNode[Expression] { * Returns the string representation of this expression that is safe to be put in * code comments of generated code. */ - protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") + protected def toCommentSafeString: String = this.toString + .replace("*/", "\\*\\/") + .replace("\\u", "\\\\u") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index fe754240dcd67..cd2ef7dcd0cd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -107,4 +107,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { true, InternalRow(UTF8String.fromString("*/"))) } + + test("\\u in the data") { + // When \ u appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape \ u to \\u. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)), + true, + InternalRow(UTF8String.fromString("\\u"))) + } } From 8e03d8f3a799beb1b33f7f50f582d725f94f18a1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 6 Dec 2015 17:35:01 -0800 Subject: [PATCH 1052/1168] [SPARK-12152][PROJECT-INFRA] Speed up Scalastyle checks by only invoking SBT once Currently, `dev/scalastyle` invokes SBT four times, but these invocations can be replaced with a single invocation, saving about one minute of build time. Author: Josh Rosen Closes #10151 from JoshRosen/speed-up-scalastyle. --- dev/scalastyle | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/dev/scalastyle b/dev/scalastyle index ad93f7e85b27c..8fd3604b9f451 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,14 +17,17 @@ # limitations under the License. # -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver scalastyle > scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt -# Check style with YARN built too -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt - -ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') -rm scalastyle.txt +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pyarn \ + -Phive \ + -Phive-thriftserver \ + scalastyle test:scalastyle \ + | awk '{if($1~/error/)print}' \ +) if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" From ef02e1dbbe5e4e7df32c58aa81c6c8f264c9091d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 7 Dec 2015 00:21:55 -0800 Subject: [PATCH 1053/1168] [SPARK-12106][STREAMING][FLAKY-TEST] BatchedWAL test transiently flaky when Jenkins load is high We need to make sure that the last entry is indeed the last entry in the queue. Author: Burak Yavuz Closes #10110 from brkyvz/batch-wal-test-fix. --- .../streaming/util/BatchedWriteAheadLog.scala | 6 ++++-- .../spark/streaming/util/WriteAheadLogSuite.scala | 14 ++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 7158abc08894a..b2cd524f28b74 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -166,10 +166,12 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp var segment: WriteAheadLogRecordHandle = null if (buffer.length > 0) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // threads may not be able to add items in order by time + val sortedByTime = buffer.sortBy(_.time) // We take the latest record for the timestamp. Please refer to the class Javadoc for // detailed explanation - val time = buffer.last.time - segment = wrappedLog.write(aggregate(buffer), time) + val time = sortedByTime.last.time + segment = wrappedLog.write(aggregate(sortedByTime), time) } buffer.foreach(_.promise.success(segment)) } catch { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index eaa88ea3cd380..ef1e89df31305 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -480,7 +480,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") { val blockingWal = new BlockingWriteAheadLog(wal, walHandle) val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) @@ -500,8 +500,14 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // rest of the records will be batched while it takes time for 3 to get written writeAsync(batchedWal, event2, 5L) writeAsync(batchedWal, event3, 8L) - writeAsync(batchedWal, event4, 12L) - writeAsync(batchedWal, event5, 10L) + // we would like event 5 to be written before event 4 in order to test that they get + // sorted before being aggregated + writeAsync(batchedWal, event5, 12L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 3) + } + writeAsync(batchedWal, event4, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) assert(batchedWal.invokePrivate(queueLength()) === 4) @@ -517,7 +523,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) - verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L)) val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) assert(records.toSet === queuedEvents) } From 0e51b47dbb68587af0b26198fa39c181a085daa4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 10:34:18 -0800 Subject: [PATCH 1054/1168] [SPARK-12032] [SQL] Re-order inner joins to do join with conditions first Currently, the order of joins is exactly the same as SQL query, some conditions may not pushed down to the correct join, then those join will become cross product and is extremely slow. This patch try to re-order the inner joins (which are common in SQL query), pick the joins that have self-contain conditions first, delay those that does not have conditions. After this patch, the TPCDS query Q64/65 can run hundreds times faster. cc marmbrus nongli Author: Davies Liu Closes #10073 from davies/reorder_joins. --- .../sql/catalyst/optimizer/Optimizer.scala | 56 ++++++++++- .../sql/catalyst/planning/patterns.scala | 40 +++++++- .../catalyst/optimizer/JoinOrderSuite.scala | 95 +++++++++++++++++++ 3 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 06d14fcf8b9c2..f6088695a9276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,14 +18,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -44,6 +42,7 @@ object DefaultOptimizer extends Optimizer { // Operator push down SetOperationPushDown, SamplePushDown, + ReorderJoin, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -711,6 +710,53 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } } +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f4f11406d7c4..cd3f15cbe107b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef /** * A pattern that matches any number of project or filter operations on top of another relational @@ -132,6 +131,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoins extends PredicateHelper { + + // flatten all inner joins, which are next to each other + def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { + case Join(left, right, Inner, cond) => + val (plans, conditions) = flattenJoin(left) + (plans ++ Seq(right), conditions ++ cond.toSeq) + + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq(plan), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + Some(flattenJoin(f)) + case j @ Join(_, _, Inner, _) => + Some(flattenJoin(j)) + case _ => None + } +} + /** * A pattern that collects all adjacent unions and returns their children as a Seq. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala new file mode 100644 index 0000000000000..9b1e16c727647 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class JoinOrderSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing) :: Nil + + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some(Seq(x, y), Seq())) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq())) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + } + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + x.join(y).join(z) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } +} From 9e0cb2c6cef2ac71a8c2aa19582710f5072ea9ff Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 7 Dec 2015 10:38:17 -0800 Subject: [PATCH 1055/1168] [SPARK-12034][SPARKR] Eliminate warnings in SparkR test cases. This PR: 1. Suppress all known warnings. 2. Cleanup test cases and fix some errors in test cases. 3. Fix errors in HiveContext related test cases. These test cases are actually not run previously due to a bug of creating TestHiveContext. 4. Support 'testthat' package version 0.11.0 which prefers that test cases be under 'tests/testthat' 5. Make sure the default Hadoop file system is local when running test cases. 6. Turn on warnings into errors. Author: Sun Rui Closes #10030 from sun-rui/SPARK-12034. --- R/pkg/inst/tests/{ => testthat}/jarTest.R | 0 .../tests/{ => testthat}/packageInAJarTest.R | 0 R/pkg/inst/tests/{ => testthat}/test_Serde.R | 0 .../tests/{ => testthat}/test_binaryFile.R | 0 .../{ => testthat}/test_binary_function.R | 0 .../tests/{ => testthat}/test_broadcast.R | 0 R/pkg/inst/tests/{ => testthat}/test_client.R | 0 .../inst/tests/{ => testthat}/test_context.R | 0 .../tests/{ => testthat}/test_includeJAR.R | 2 +- .../{ => testthat}/test_includePackage.R | 0 R/pkg/inst/tests/{ => testthat}/test_mllib.R | 14 ++-- .../{ => testthat}/test_parallelize_collect.R | 0 R/pkg/inst/tests/{ => testthat}/test_rdd.R | 0 .../inst/tests/{ => testthat}/test_shuffle.R | 0 .../inst/tests/{ => testthat}/test_sparkSQL.R | 68 +++++++++++-------- R/pkg/inst/tests/{ => testthat}/test_take.R | 0 .../inst/tests/{ => testthat}/test_textFile.R | 0 R/pkg/inst/tests/{ => testthat}/test_utils.R | 0 R/pkg/tests/run-all.R | 3 + R/run-tests.sh | 2 +- 20 files changed, 50 insertions(+), 39 deletions(-) rename R/pkg/inst/tests/{ => testthat}/jarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/packageInAJarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_Serde.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binaryFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binary_function.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_broadcast.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_client.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_context.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_includeJAR.R (94%) rename R/pkg/inst/tests/{ => testthat}/test_includePackage.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_mllib.R (90%) rename R/pkg/inst/tests/{ => testthat}/test_parallelize_collect.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_rdd.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_shuffle.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_sparkSQL.R (97%) rename R/pkg/inst/tests/{ => testthat}/test_take.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_textFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_utils.R (100%) diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R similarity index 100% rename from R/pkg/inst/tests/jarTest.R rename to R/pkg/inst/tests/testthat/jarTest.R diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R similarity index 100% rename from R/pkg/inst/tests/packageInAJarTest.R rename to R/pkg/inst/tests/testthat/packageInAJarTest.R diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R similarity index 100% rename from R/pkg/inst/tests/test_Serde.R rename to R/pkg/inst/tests/testthat/test_Serde.R diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R similarity index 100% rename from R/pkg/inst/tests/test_binaryFile.R rename to R/pkg/inst/tests/testthat/test_binaryFile.R diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R similarity index 100% rename from R/pkg/inst/tests/test_binary_function.R rename to R/pkg/inst/tests/testthat/test_binary_function.R diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R similarity index 100% rename from R/pkg/inst/tests/test_broadcast.R rename to R/pkg/inst/tests/testthat/test_broadcast.R diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/testthat/test_client.R similarity index 100% rename from R/pkg/inst/tests/test_client.R rename to R/pkg/inst/tests/testthat/test_client.R diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/testthat/test_context.R similarity index 100% rename from R/pkg/inst/tests/test_context.R rename to R/pkg/inst/tests/testthat/test_context.R diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_includeJAR.R similarity index 94% rename from R/pkg/inst/tests/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_includeJAR.R index cc1faeabffe30..f89aa8e507fd5 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/testthat/test_includeJAR.R @@ -20,7 +20,7 @@ runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, args = c(jarPath, scriptPath), diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/test_includePackage.R rename to R/pkg/inst/tests/testthat/test_includePackage.R diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R similarity index 90% rename from R/pkg/inst/tests/test_mllib.R rename to R/pkg/inst/tests/testthat/test_mllib.R index e0667e5e22c18..08099dd96a87b 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -26,7 +26,7 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) test_that("glm and predict", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) test <- select(training, "Sepal_Length") model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) @@ -39,7 +39,7 @@ test_that("glm and predict", { }) test_that("glm should work with long formula", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -51,7 +51,7 @@ test_that("glm should work with long formula", { }) test_that("predictions match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -59,7 +59,7 @@ test_that("predictions match with native glm", { }) test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ . - Species + 0, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) @@ -67,7 +67,7 @@ test_that("dot minus and intercept vs native glm", { }) test_that("feature interaction vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) @@ -75,7 +75,7 @@ test_that("feature interaction vs native glm", { }) test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) coefs <- unlist(stats$coefficients) devianceResiduals <- unlist(stats$devianceResiduals) @@ -92,7 +92,7 @@ test_that("summary coefficients match with native glm", { }) test_that("summary coefficients match with native glm of family 'binomial'", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/test_parallelize_collect.R rename to R/pkg/inst/tests/testthat/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R similarity index 100% rename from R/pkg/inst/tests/test_rdd.R rename to R/pkg/inst/tests/testthat/test_rdd.R diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R similarity index 100% rename from R/pkg/inst/tests/test_shuffle.R rename to R/pkg/inst/tests/testthat/test_shuffle.R diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R similarity index 97% rename from R/pkg/inst/tests/test_sparkSQL.R rename to R/pkg/inst/tests/testthat/test_sparkSQL.R index 6ef03ae97635e..39fc94aea5fb1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -133,38 +133,45 @@ test_that("create DataFrame from RDD", { expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- jsonFile(sqlContext, jsonPathNa) - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) - expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) - schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df2 <- createDataFrame(sqlContext, df.toRDD, schema) - df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) + df <- read.df(sqlContext, jsonPathNa, "json", schema) + df2 <- createDataFrame(sqlContext, toRDD(df), schema) + df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) expect_equal(columns(df2), c("name", "age", "height")) expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) + expect_equal(as.list(collect(where(df2, df2$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19, 23, 18), - height=c(164.10, 181.4, 173.7)) + age=c(19L, 23L, 18L), + height=c(176.5, 181.4, 173.7)) df <- createDataFrame(sqlContext, localDF, schema) expect_is(df, "DataFrame") expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) + expect_equal(as.list(collect(where(df, df$name == "John"))), + list(name = "John", age = 19L, height = 176.5)) + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + df <- read.df(hiveCtx, jsonPathNa, "json", schema) + invisible(insertInto(df, "people")) + expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + c(16)) + expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + c(176.5)) }) test_that("convert NAs to null type in DataFrames", { @@ -250,7 +257,7 @@ test_that("create DataFrame from list or data.frame", { ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - irisdf <- createDataFrame(sqlContext, iris) + irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) iris_collected <- collect(irisdf) expect_equivalent(iris_collected[,-5], iris[,-5]) expect_equal(iris_collected$Species, as.character(iris$Species)) @@ -463,7 +470,7 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) expect_is(unioned, "RDD") - expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(getSerializedMode(unioned), "byte") expect_equal(collect(unioned)[[2]]$name, "Andy") }) @@ -485,13 +492,13 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { unionByte <- unionRDD(rdd, dfRDD) expect_is(unionByte, "RDD") - expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(getSerializedMode(unionByte), "byte") expect_equal(collect(unionByte)[[1]], 1) expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) expect_is(unionString, "RDD") - expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(getSerializedMode(unionString), "byte") expect_equal(collect(unionString)[[1]], "Michael") expect_equal(collect(unionString)[[5]]$name, "Andy") }) @@ -504,7 +511,7 @@ test_that("objectFile() works with row serialization", { objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") - expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -849,6 +856,7 @@ test_that("write.df() as parquet file", { }) test_that("test HiveContext", { + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) }, @@ -863,10 +871,10 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - saveAsTable(df, "json", "json", "append", path = jsonPath2) - df3 <- sql(hiveCtx, "select * from json") + invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") - expect_equal(count(df3), 6) + expect_equal(count(df3), 3) }) test_that("column operators", { @@ -1311,7 +1319,7 @@ test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") - expect_equal(SparkR:::getSerializedMode(testRDD), "string") + expect_equal(getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -1641,7 +1649,7 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) -irisDF <- createDataFrame(sqlContext, iris) +irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) @@ -1670,7 +1678,7 @@ test_that("attach() on a DataFrame", { }) test_that("with() on a DataFrame", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) expect_error(Sepal_Length) sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/testthat/test_take.R similarity index 100% rename from R/pkg/inst/tests/test_take.R rename to R/pkg/inst/tests/testthat/test_take.R diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R similarity index 100% rename from R/pkg/inst/tests/test_textFile.R rename to R/pkg/inst/tests/testthat/test_textFile.R diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R similarity index 100% rename from R/pkg/inst/tests/test_utils.R rename to R/pkg/inst/tests/testthat/test_utils.R diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 4f8a1ed2d83ef..1d04656ac2594 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,4 +18,7 @@ library(testthat) library(SparkR) +# Turn all warnings into errors +options("warn" = 2) + test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..e64a4ea94c584 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then From 3af95967c894e37b061237b930193133130e12f0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 11:00:25 -0800 Subject: [PATCH 1056/1168] [SPARK-12132] [PYSPARK] raise KeyboardInterrupt inside SIGINT handler Currently, the current line is not cleared by Cltr-C After this patch ``` >>> asdfasdf^C Traceback (most recent call last): File "~/spark/python/pyspark/context.py", line 225, in signal_handler raise KeyboardInterrupt() KeyboardInterrupt ``` It's still worse than 1.5 (and before). Author: Davies Liu Closes #10134 from davies/fix_cltrc. --- python/pyspark/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 77710a13394c6..529d16b480399 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # create a signal handler which would be invoked on receiving SIGINT def signal_handler(signal, frame): self.cancelAllJobs() + raise KeyboardInterrupt() # see http://stackoverflow.com/questions/23206787/ if isinstance(threading.current_thread(), threading._MainThread): From e0847436c7354bee160b1cb9c41e1832b7eacef7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Dec 2015 11:03:59 -0800 Subject: [PATCH 1057/1168] [SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovery from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to it. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there may be a non-zero chance that the saving and recovery fails. To be resilient, this PR repartitions the previous state RDD if the partitioner is not detected. Author: Tathagata Das Closes #9988 from tdas/SPARK-11932. --- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/dstream/TrackStateDStream.scala | 39 ++-- .../spark/streaming/rdd/TrackStateRDD.scala | 29 ++- .../spark/streaming/CheckpointSuite.scala | 189 +++++++++++++----- .../spark/streaming/TestSuiteBase.scala | 6 + .../streaming/TrackStateByKeySuite.scala | 77 +++++-- 6 files changed, 258 insertions(+), 84 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index fd0e8d5d690b6..d0046afdeb447 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -277,7 +277,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 0ada1111ce30a..ea6213420e7ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, E]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, validTime - ) + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + TrackStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) } + // Compute the new state RDD with previous state RDD and partitioned data RDD - parent.getOrCompute(validTime).map { dataRDD => - val partitionedDataRDD = dataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds } + Some(new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 7050378d0feb0..30aafcf1460e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, T] = { + updateTime: Time): TrackStateRDD[K, V, S, E] = { val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index b1cbc7163bee3..cd28d3cf408d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -33,17 +33,149 @@ import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + private def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. + assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.map(_.flatten) + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} + /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase { checkpointWriter.stop() } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index a45c92d9c7bc8..be0f4636a6cb8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 58aef74c0040f..1fc320d31b18b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -25,31 +25,27 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { +class TrackStateByKeySuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null - private var ssc: StreamingContext = null - private var checkpointDir: File = null - private val batchDuration = Seconds(1) + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) - } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } checkpointDir = Utils.createTempDir("checkpoint") - - ssc = new StreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointDir.toString) } after { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } } override def beforeAll(): Unit = { @@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(dstreamImpl.stateClass === classOf[Double]) assert(dstreamImpl.emittedClass === classOf[Long]) } - + val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types @@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef expectedCheckpointDuration: Duration, explicitCheckpointDuration: Option[Duration] = None ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + try { - ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) val dummyFunc = (value: Option[Int], state: State[Int]) => 0 val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) @@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef trackStateStream.checkpoint(d) } trackStateStream.register() + ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations assert(trackStateStream.checkpointDuration === null) assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) } finally { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + ssc.stop(stopSparkContext = false) } } @@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) } + + test("trackStateByKey - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + trackStateStream.checkpoint(checkpointDuration) + trackStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: StateSpec[K, Int, S, T], @@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] @@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef stateSnapshotStream.register() val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) (collectedOutputs, collectedStateSnapshots) } From c7bb3b2bc3f83574d60aa56fda68b69c5d32403d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 7 Dec 2015 12:01:09 -0800 Subject: [PATCH 1058/1168] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize Merged #10051 again since #10083 is resolved. This reverts commit 328b757d5d4486ea3c2e246780792d7a57ee85e5. Author: Shixiong Zhu Closes #10167 from zsxwing/merge-SPARK-12060. --- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7d..ea718a0edbe71 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 0000000000000..92e45224db81c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From bb017dd5d67c485d22e7e787bafc95b58dd46982 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Dec 2015 13:16:47 -0800 Subject: [PATCH 1059/1168] [SPARK-11963][DOC] Add docs for QuantileDiscretizer https://issues.apache.org/jira/browse/SPARK-11963 Author: Xusen Yin Closes #9962 from yinxusen/SPARK-11963. --- docs/ml-features.md | 65 +++++++++++++++++ .../ml/JavaQuantileDiscretizerExample.java | 71 +++++++++++++++++++ .../ml/QuantileDiscretizerExample.scala | 49 +++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 05c2c96c5ec5a..b499d6ec3b90e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1705,6 +1705,71 @@ print(output.select("features", "clicked").first()) +## QuantileDiscretizer + +`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned +categorical features. +The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. +The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. +This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may +find fewer depending on the data sample values. + +Note that the result may be different every time you run it, since the sample strategy behind it is +non-deterministic. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`: + +~~~ + id | hour +----|------ + 0 | 18.0 +----|------ + 1 | 19.0 +----|------ + 2 | 8.0 +----|------ + 3 | 5.0 +----|------ + 4 | 2.2 +~~~ + +`hour` is a continuous feature with `Double` type. We want to turn the continuous feature into +categorical one. Given `numBuckets = 3`, we should get the following DataFrame: + +~~~ + id | hour | result +----|------|------ + 0 | 18.0 | 2.0 +----|------|------ + 1 | 19.0 | 2.0 +----|------|------ + 2 | 8.0 | 1.0 +----|------|------ + 3 | 5.0 | 1.0 +----|------|------ + 4 | 2.2 | 0.0 +~~~ + +
    +
    + +Refer to the [QuantileDiscretizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.QuantileDiscretizer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala %} +
    + +
    + +Refer to the [QuantileDiscretizer Java docs](api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %} +
    +
    + # Feature Selectors ## VectorSlicer diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java new file mode 100644 index 0000000000000..251ae79d9a108 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.QuantileDiscretizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaQuantileDiscretizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize( + Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) + ) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + QuantileDiscretizer discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3); + + DataFrame result = discretizer.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala new file mode 100644 index 0000000000000..8f29b7eaa6d26 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.QuantileDiscretizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object QuantileDiscretizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("QuantileDiscretizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) + val df = sc.parallelize(data).toDF("id", "hour") + + val discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3) + + val result = discretizer.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From db2eca8c9a5da66c0d244dce53084dae5a7e583b Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 7 Dec 2015 14:58:09 -0800 Subject: [PATCH 1060/1168] [SPARK-11884] Drop multiple columns in the DataFrame API See the thread Ben started: http://search-hadoop.com/m/q3RTtveEuhjsr7g/ This PR adds drop() method to DataFrame which accepts multiple column names Author: tedyu Closes #9862 from ted-yu/master. --- .../org/apache/spark/sql/DataFrame.scala | 24 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eb8700369275e..243a8c853f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1261,16 +1261,24 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { + drop(Seq(colName) : _*) + } + + /** + * Returns a new [[DataFrame]] with columns dropped. + * This is a no-op if schema doesn't contain column name(s). + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def drop(colNames: String*): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldDrop = schema.exists(f => resolver(f.name, colName)) - if (shouldDrop) { - val colsAfterDrop = schema.filter { field => - val name = field.name - !resolver(name, colName) - }.map(f => Column(f.name)) - select(colsAfterDrop : _*) - } else { + val remainingCols = + schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) + if (remainingCols.size == this.schema.size) { this + } else { + this.select(remainingCols: _*) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 76e9648aa7533..605a6549dd686 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -378,6 +378,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("value")) } + test("drop columns using drop") { + val src = Seq((0, 2, 3)).toDF("a", "b", "c") + val df = src.drop("a", "b") + checkAnswer(df, Row(3)) + assert(df.schema.map(_.name) === Seq("c")) + } + test("drop unknown column (no-op)") { val df = testData.drop("random") checkAnswer( From 8658058afc47827c702233d44fed879615f51099 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 7 Dec 2015 15:01:00 -0800 Subject: [PATCH 1061/1168] [SPARK-12184][PYTHON] Make python api doc for pivot consistant with scala doc In SPARK-11946 the API for pivot was changed a bit and got updated doc, the doc changes were not made for the python api though. This PR updates the python doc to be consistent. Author: Andrew Ray Closes #10176 from aray/sql-pivot-python-doc. --- python/pyspark/sql/group.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 1911588309aff..9ca303a974cd4 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,16 +169,20 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and perform the specified aggregation. + """ + Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + There are two versions of pivot function: one that requires the caller to specify the list + of distinct values to pivot on, and one that does not. The latter is more concise but less + efficient, because Spark needs to first compute the list of distinct values internally. - :param pivot_col: Column to pivot - :param values: Optional list of values of pivot column that will be translated to columns in - the output DataFrame. If values are not provided the method will do an immediate call - to .distinct() on the pivot column. + :param pivot_col: Name of the column to pivot. + :param values: List of values that will be translated to columns in the output DataFrame. + // Compute the sum of earnings for each year by course with each course as a separate column >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + // Or without specifying column values (less efficient) >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ From 621193071402462a3850647531c4473049d1676f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 7 Dec 2015 16:37:09 -0800 Subject: [PATCH 1062/1168] [SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib Switched from using SQLContext constructor to using getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib. There were no uses of the constructor in spark.ml. CC: mengxr yhuai Author: Joseph K. Bradley Closes #10161 from jkbradley/mllib-sqlcontext-fix. --- .../apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../classification/impl/GLMClassificationModel.scala | 4 ++-- .../spark/mllib/clustering/GaussianMixtureModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ++-- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- .../mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../spark/mllib/regression/impl/GLMRegressionModel.scala | 4 ++-- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- .../spark/mllib/tree/model/treeEnsembleModels.scala | 4 ++-- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 54b03a9f90283..2aa6aec0b4347 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) sqlContext.createDataFrame(indexedRowMatrix.rows) } @@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } @@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. - val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) sqlContext.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae06e8..aef9ef2cb052d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d28b..2910c027ae06d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2115f7d99c182..74d13e4f77945 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a741584982725..91fa9b0d3590d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7cd9b08fa8e0e..bb1804505948b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( @@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index d4d022afde051..eaa99cfe82e27 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b693f3c8e4bd9..23b1514e3080e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val vectorSize = model.values.head.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 46562eb2ad0f7..0dc40483dd0ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ec78ea24539b5..f235089873ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a5702636..02af281fb726b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 54c136aecf660..89c470d573431 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 90e032e3d9842..3f427f0be3af2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -408,7 +408,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -468,7 +468,7 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) From cfd3bf4c15f14f835887d2db8fe1ff2f8b766792 Mon Sep 17 00:00:00 2001 From: somideshmukh Date: Mon, 7 Dec 2015 23:26:34 -0800 Subject: [PATCH 1063/1168] [SPARK-11551][DOC][EXAMPLE] Replace example code in ml-features.md using include_example Made new patch contaning only markdown examples moved to exmaple/folder. Ony three java code were not shfted since they were contaning compliation error ,these classes are 1)StandardScale 2)NormalizerExample 3)VectorIndexer Author: Xusen Yin Author: somideshmukh Closes #10002 from somideshmukh/SomilBranch1.33. --- docs/ml-features.md | 1109 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 70 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 50 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 52 + .../examples/ml/JavaOneHotEncoderExample.java | 77 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 53 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 60 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 42 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 41 + .../main/python/ml/onehot_encoder_example.py | 47 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 42 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 39 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 51 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 53 + .../examples/ml/MinMaxScalerExample.scala | 49 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 50 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 54 + .../ml/PolynomialExpansionExample.scala | 53 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 51 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 49 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 53 + .../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2806 insertions(+), 1058 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index b499d6ec3b90e..5105a948fec8e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
    @@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    @@ -961,29 +496,7 @@ indexed.show() Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -991,45 +504,7 @@ encoded.select("id", "categoryVec").foreach(println) Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1037,24 +512,7 @@ DataFrame encoded = encoder.transform(indexed); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1078,23 +536,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1102,30 +544,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1133,17 +552,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1160,22 +569,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1183,24 +577,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1208,19 +585,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1244,23 +609,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1268,25 +617,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1294,20 +625,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1337,21 +655,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1360,24 +664,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1401,23 +688,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1425,38 +696,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1464,19 +704,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1508,25 +736,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1534,41 +744,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1576,19 +752,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
    - ## VectorAssembler @@ -1632,19 +797,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1652,36 +805,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1689,19 +813,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
    @@ -1831,33 +943,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1865,41 +951,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -1936,21 +988,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -1958,38 +996,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -1997,21 +1004,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 0000000000000..9698cac504371 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 0000000000000..b06a23e76d604 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 0000000000000..35c0d534a45e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..2898accec61b0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 0000000000000..138b3ab6aba44 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 0000000000000..8fd75ed8b5f4e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 0000000000000..6283a355e1fef --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 0000000000000..172a9cc6feb28 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 0000000000000..8282fab084f36 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 0000000000000..668f71e64056b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 0000000000000..1e1062b541ad9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 0000000000000..0cbdc97e8ae30 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 0000000000000..b6b201c6b68d2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 0000000000000..05d12c1e702f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 0000000000000..617dc3f66e3bf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 0000000000000..7e230b5897c1e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 0000000000000..06b4bf6bf8ff6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 0000000000000..4d5cb04ff5e2b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 0000000000000..960ad208be12e --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 0000000000000..a12750aa9248a --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 0000000000000..c85cb0d89543c --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 0000000000000..f2d85f53e7219 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 0000000000000..833d93e976a7e --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000000..7529dfd09213a --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 0000000000000..8b66140a40a7a --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data,["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 0000000000000..030a6132a451a --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 0000000000000..b544a14700762 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 0000000000000..139acecbfb53f --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 0000000000000..01f94af8ca752 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 0000000000000..58a8cb5d56b73 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 0000000000000..ce9b225be5357 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 0000000000000..04f64839f188d --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 0000000000000..cc00d1454f2e0 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 0000000000000..e724aa587294b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 0000000000000..30c2776d39688 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. + val bucketedData = bucketizer.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 0000000000000..314c2c28a2a10 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 0000000000000..ac50bb7b2b155 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 0000000000000..dac3679a5bf7e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. + val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 0000000000000..8a85f71b56f3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 0000000000000..17571f0aad793 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 0000000000000..4512736943dd5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder().setInputCol("categoryIndex"). + setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 0000000000000..a18d4f33973d8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 0000000000000..b8e9e6952a5ea --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 0000000000000..286866edea502 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 0000000000000..646ce0f13ecf5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 0000000000000..655ffce08d3ab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 0000000000000..1be8a5f33f7c0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 0000000000000..01e0d1388a2f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 0000000000000..d527924419f81 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 0000000000000..14279d610fda8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 0000000000000..04f19829eff87 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From b11245ae741d96c53d13365e254ad9f53e37da45 Mon Sep 17 00:00:00 2001 From: Takahashi Hiroshi Date: Mon, 7 Dec 2015 23:46:55 -0800 Subject: [PATCH 1064/1168] [SPARK-10259][ML] Add @since annotation to ml.classification Add since annotation to ml.classification Author: Takahashi Hiroshi Closes #8534 from taishi-oss/issue10259. --- .../DecisionTreeClassifier.scala | 30 +++++++-- .../ml/classification/GBTClassifier.scala | 35 ++++++++-- .../classification/LogisticRegression.scala | 64 +++++++++++++++---- .../MultilayerPerceptronClassifier.scala | 23 +++++-- .../spark/ml/classification/NaiveBayes.scala | 19 ++++-- .../spark/ml/classification/OneVsRest.scala | 24 +++++-- .../RandomForestClassifier.scala | 34 ++++++++-- 7 files changed, 185 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index c478aea44ace8..8c4cec1326653 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest @@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeClassifier(override val uid: String) +final class DecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) // Override parameter setters from parent trait for Java API compatibility. + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { @@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.1") override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } @@ -104,12 +119,13 @@ object DecisionTreeClassifier { * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental final class DecisionTreeClassificationModel private[ml] ( - override val uid: String, - override val rootNode: Node, - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: Node, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 74aef94bf7675..cda2bca58c50d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ +@Since("1.4.0") @Experimental -final class GBTClassifier(override val uid: String) +final class GBTClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." */ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this @@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) @@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String) // Parameters from GBTParams: + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTClassifier: @@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String) * (default = logistic) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", @@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String) setDefault(lossType -> "logistic") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ @@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.1") override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) } @@ -164,12 +185,13 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.6.0") @Experimental final class GBTClassificationModel private[ml]( - override val uid: String, + @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - override val numFeatures: Int) + @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { @@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml]( * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ + @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"GBTClassificationModel (uid=$uid) with $numTrees trees" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d320d64dd90d0..19cc323d5073f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -154,11 +154,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * Currently, this class only supports binary classification. It will support multiclass * in the future. */ +@Since("1.2.0") @Experimental -class LogisticRegression(override val uid: String) +class LogisticRegression @Since("1.2.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("logreg")) /** @@ -166,6 +169,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.2.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -176,6 +180,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -184,6 +189,7 @@ class LogisticRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.2.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -193,6 +199,7 @@ class LogisticRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -201,6 +208,7 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.4.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -213,11 +221,14 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold /** @@ -226,11 +237,14 @@ class LogisticRegression(override val uid: String) * Default is empty, so all instances have weight one. * @group setParam */ + @Since("1.6.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { @@ -384,11 +398,14 @@ class LogisticRegression(override val uid: String) model.setSummary(logRegSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } +@Since("1.6.0") object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + @Since("1.6.0") override def load(path: String): LogisticRegression = super.load(path) } @@ -396,23 +413,28 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { * :: Experimental :: * Model produced by [[LogisticRegression]]. */ +@Since("1.4.0") @Experimental class LogisticRegressionModel private[ml] ( - override val uid: String, - val coefficients: Vector, - val intercept: Double) + @Since("1.4.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds /** Margin (rawPrediction) for class label 1. For binary classification only. */ @@ -426,8 +448,10 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + @Since("1.6.0") override val numFeatures: Int = coefficients.size + @Since("1.3.0") override val numClasses: Int = 2 private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None @@ -436,6 +460,7 @@ class LogisticRegressionModel private[ml] ( * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. */ + @Since("1.5.0") def summary: LogisticRegressionTrainingSummary = trainingSummary match { case Some(summ) => summ case None => @@ -451,6 +476,7 @@ class LogisticRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. */ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** @@ -493,6 +519,7 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(-m, m) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) @@ -710,12 +737,13 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionTrainingSummary private[classification] ( - predictions: DataFrame, - probabilityCol: String, - labelCol: String, - featuresCol: String, - val objectiveHistory: Array[Double]) + @Since("1.5.0") predictions: DataFrame, + @Since("1.5.0") probabilityCol: String, + @Since("1.5.0") labelCol: String, + @Since("1.6.0") featuresCol: String, + @Since("1.5.0") val objectiveHistory: Array[Double]) extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { @@ -731,11 +759,13 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionSummary private[classification] ( - @transient override val predictions: DataFrame, - override val probabilityCol: String, - override val labelCol: String, - override val featuresCol: String) extends LogisticRegressionSummary { + @Since("1.5.0") @transient override val predictions: DataFrame, + @Since("1.5.0") override val probabilityCol: String, + @Since("1.5.0") override val labelCol: String, + @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { + private val sqlContext = predictions.sqlContext import sqlContext.implicits._ @@ -760,6 +790,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** @@ -768,6 +799,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** @@ -777,6 +809,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** @@ -785,6 +818,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") } @@ -797,6 +831,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") } @@ -809,6 +844,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index cd7462596dd9e..a691aa005ef54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} @@ -104,19 +104,23 @@ private object LabelConverter { * Each layer has sigmoid activation function, output layer has softmax. * Number of inputs has to be equal to the size of feature vectors. * Number of outputs has to be equal to the total number of labels. - * */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifier(override val uid: String) +class MultilayerPerceptronClassifier @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) /** @group setParam */ + @Since("1.5.0") def setLayers(value: Array[Int]): this.type = set(layers, value) /** @group setParam */ + @Since("1.5.0") def setBlockSize(value: Int): this.type = set(blockSize, value) /** @@ -124,6 +128,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @@ -132,14 +137,17 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 1E-4. * @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** * Set the seed for weights initialization. * @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -173,14 +181,16 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ +@Since("1.5.0") @Experimental class MultilayerPerceptronClassificationModel private[ml] ( - override val uid: String, - val layers: Array[Int], - val weights: Vector) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val layers: Array[Int], + @Since("1.5.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { + @Since("1.6.0") override val numFeatures: Int = layers.head private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -200,6 +210,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index c512a2cb8bf3d..718f49d3aedcd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -72,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Since("1.5.0") @Experimental -class NaiveBayes(override val uid: String) +class NaiveBayes @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) /** @@ -84,6 +87,7 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ + @Since("1.5.0") def setSmoothing(value: Double): this.type = set(smoothing, value) setDefault(smoothing -> 1.0) @@ -93,6 +97,7 @@ class NaiveBayes(override val uid: String) * Default is "multinomial" * @group setParam */ + @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -102,6 +107,7 @@ class NaiveBayes(override val uid: String) NaiveBayesModel.fromOld(oldModel, this) } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } @@ -119,11 +125,12 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Since("1.5.0") @Experimental class NaiveBayesModel private[ml] ( - override val uid: String, - val pi: Vector, - val theta: Matrix) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val pi: Vector, + @Since("1.5.0") val theta: Matrix) extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { @@ -148,8 +155,10 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + @Since("1.6.0") override val numFeatures: Int = theta.numCols + @Since("1.5.0") override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { @@ -206,10 +215,12 @@ class NaiveBayesModel private[ml] ( } } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } + @Since("1.5.0") override def toString: String = { s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index debc164bf2432..08a51109d6c62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap} @@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ +@Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - override val uid: String, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_, _]]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -134,6 +137,7 @@ final class OneVsRestModel private[ml] ( .drop(accColName) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) @@ -150,30 +154,39 @@ final class OneVsRestModel private[ml] ( * Each example is scored against all k models and the model with highest score * is picked to label the example. */ +@Since("1.4.0") @Experimental -final class OneVsRest(override val uid: String) +final class OneVsRest @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ + @Since("1.4.0") def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } + @Since("1.4.0") override def fit(dataset: DataFrame): OneVsRestModel = { // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) @@ -222,6 +235,7 @@ final class OneVsRest(override val uid: String) copyValues(model) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRest = { val copied = defaultCopy(extra).asInstanceOf[OneVsRest] if (isDefined(classifier)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bae329692a68d..d6d85ad2533a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} @@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._ * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class RandomForestClassifier(override val uid: String) +final class RandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -99,15 +114,19 @@ final class RandomForestClassifier(override val uid: String) new RandomForestClassificationModel(trees, numFeatures, numClasses) } + @Since("1.4.1") override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -120,12 +139,13 @@ object RandomForestClassifier { * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ +@Since("1.4.0") @Experimental final class RandomForestClassificationModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -141,11 +161,13 @@ final class RandomForestClassificationModel private[ml] ( numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -186,11 +208,13 @@ final class RandomForestClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } From c0e42e58960f0efdfdcfd63fbca51301c58f7bc4 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 7 Dec 2015 23:50:57 -0800 Subject: [PATCH 1065/1168] [SPARK-11958][SPARK-11957][ML][DOC] SQLTransformer user guide and example code Add ```SQLTransformer``` user guide, example code and make Scala API doc more clear. Author: Yanbo Liang Closes #10006 from yanboliang/spark-11958. --- docs/ml-features.md | 59 +++++++++++++++++++ .../ml/JavaSQLTransformerExample.java | 59 +++++++++++++++++++ .../src/main/python/ml/sql_transformer.py | 40 +++++++++++++ .../examples/ml/SQLTransformerExample.scala | 45 ++++++++++++++ .../spark/ml/feature/SQLTransformer.scala | 11 +++- 5 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java create mode 100644 examples/src/main/python/ml/sql_transformer.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 5105a948fec8e..f85e0d56d2e40 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -756,6 +756,65 @@ for more details on the API. +## SQLTransformer + +`SQLTransformer` implements the transformations which are defined by SQL statement. +Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +where `"__THIS__"` represents the underlying table of the input dataset. +The select clause specifies the fields, constants, and expressions to display in +the output, it can be any select clause that Spark SQL supports. Users can also +use Spark SQL built-in function and UDFs to operate on these selected columns. +For example, `SQLTransformer` supports statements like: + +* `SELECT a, a + b AS a_b FROM __THIS__` +* `SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5` +* `SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b` + +**Examples** + +Assume that we have the following DataFrame with columns `id`, `v1` and `v2`: + +~~~~ + id | v1 | v2 +----|-----|----- + 0 | 1.0 | 3.0 + 2 | 2.0 | 5.0 +~~~~ + +This is the output of the `SQLTransformer` with statement `"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"`: + +~~~~ + id | v1 | v2 | v3 | v4 +----|-----|-----|-----|----- + 0 | 1.0 | 3.0 | 4.0 | 3.0 + 2 | 2.0 | 5.0 | 7.0 |10.0 +~~~~ + +
    +
    + +Refer to the [SQLTransformer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.SQLTransformer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/SQLTransformerExample.scala %} +
    + +
    + +Refer to the [SQLTransformer Java docs](api/java/org/apache/spark/ml/feature/SQLTransformer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java %} +
    + +
    + +Refer to the [SQLTransformer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.SQLTransformer) for more details on the API. + +{% include_example python/ml/sql_transformer.py %} +
    +
    + ## VectorAssembler `VectorAssembler` is a transformer that combines a given list of columns into a single vector diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java new file mode 100644 index 0000000000000..d55c70796a967 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.SQLTransformer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaSQLTransformerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 1.0, 3.0), + RowFactory.create(2, 2.0, 5.0) + )); + StructType schema = new StructType(new StructField [] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + SQLTransformer sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); + + sqlTrans.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py new file mode 100644 index 0000000000000..9575d728d8159 --- /dev/null +++ b/examples/src/main/python/ml/sql_transformer.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import SQLTransformer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="SQLTransformerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, 1.0, 3.0), + (2, 2.0, 5.0) + ], ["id", "v1", "v2"]) + sqlTrans = SQLTransformer( + statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + sqlTrans.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala new file mode 100644 index 0000000000000..014abd1fdbc63 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.SQLTransformer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object SQLTransformerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SQLTransformerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + + sqlTrans.transform(df).show() + // $example off$ + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 3a735017ba836..c09f4d076c964 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -27,9 +27,16 @@ import org.apache.spark.sql.types.StructType /** * :: Experimental :: - * Implements the transforms which are defined by SQL statement. - * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * Implements the transformations which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' * where '__THIS__' represents the underlying table of the input dataset. + * The select clause specifies the fields, constants, and expressions to display in + * the output, it can be any select clause that Spark SQL supports. Users can also + * use Spark SQL built-in function and UDFs to operate on these selected columns. + * For example, [[SQLTransformer]] supports statements like: + * - SELECT a, a + b AS a_b FROM __THIS__ + * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 + * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b */ @Experimental @Since("1.6.0") From 48f29142c554a2228c3e87e15bb412df4335f5b4 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 8 Dec 2015 11:02:35 +0000 Subject: [PATCH 1066/1168] =?UTF-8?q?[SPARK-12103][STREAMING][KAFKA][DOC]?= =?UTF-8?q?=20document=20that=20K=20means=20Key=20and=20V=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …means Value Author: cody koeninger Closes #10132 from koeninger/SPARK-12103. --- .../spark/streaming/kafka/KafkaUtils.scala | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index ad2fb8aa5f24c..fe572220528d5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -51,6 +51,7 @@ object KafkaUtils { * in its own thread * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( ssc: StreamingContext, @@ -74,6 +75,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel Storage level to use for storing the received objects + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, @@ -93,6 +99,7 @@ object KafkaUtils { * @param groupId The group id for this consumer * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -111,6 +118,7 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -135,6 +143,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread * @param storageLevel RDD storage level. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, @@ -219,6 +232,11 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[ K: ClassTag, @@ -251,6 +269,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[ K: ClassTag, @@ -288,6 +312,15 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @param keyClass type of Kafka message key + * @param valueClass type of Kafka message value + * @param keyDecoderClass type of Kafka message key decoder + * @param valueDecoderClass type of Kafka message value decoder + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, @@ -321,6 +354,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, @@ -373,6 +412,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[ K: ClassTag, @@ -419,6 +464,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[ K: ClassTag, @@ -470,6 +520,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, @@ -529,6 +585,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, From 675b92ae8894c3747cc50c9d653af08c39603de4 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 8 Dec 2015 11:05:06 +0000 Subject: [PATCH 1067/1168] [SPARK-12166][TEST] Unset hadoop related environment in testing Author: Jeff Zhang Closes #10172 from zjffdu/SPARK-12166. --- bin/spark-class | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bin/spark-class b/bin/spark-class index 87d06693af4fe..5d964ba96abd8 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -71,6 +71,12 @@ fi export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" +# For tests +if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR +fi + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. From 7d60877cd3c511e504ba06dc9662b03b6bd7a1d7 Mon Sep 17 00:00:00 2001 From: Nakul Jindal Date: Tue, 8 Dec 2015 11:08:27 +0000 Subject: [PATCH 1068/1168] [SPARK-11439][ML] Optimization of creating sparse feature without dense one Sparse feature generated in LinearDataGenerator does not create dense vectors as an intermediate any more. Author: Nakul Jindal Closes #9756 from nakul02/SPARK-11439_sparse_without_creating_dense_feature. --- .../mllib/util/LinearDataGenerator.scala | 44 ++-- .../evaluation/RegressionEvaluatorSuite.scala | 6 +- .../ml/regression/LinearRegressionSuite.scala | 214 ++++++++++-------- 3 files changed, 142 insertions(+), 122 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 6ff07eed6cfd2..094528e2ece06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -24,7 +24,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{BLAS, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -131,35 +131,27 @@ object LinearDataGenerator { eps: Double, sparsity: Double): Seq[LabeledPoint] = { require(0.0 <= sparsity && sparsity <= 1.0) - val rnd = new Random(seed) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble())) - - val sparseRnd = new Random(seed) - x.foreach { v => - var i = 0 - val len = v.length - while (i < len) { - if (sparseRnd.nextDouble() < sparsity) { - v(i) = 0.0 - } else { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - } - i += 1 - } - } - val y = x.map { xi => - blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() - } + val rnd = new Random(seed) + def rndElement(i: Int) = {(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)} - y.zip(x).map { p => - if (sparsity == 0.0) { + if (sparsity == 0.0) { + (0 until nPoints).map { _ => + val features = Vectors.dense(weights.indices.map { rndElement(_) }.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with DenseVector - LabeledPoint(p._1, Vectors.dense(p._2)) - } else { + LabeledPoint(label, features) + } + } else { + (0 until nPoints).map { _ => + val indices = weights.indices.filter { _ => rnd.nextDouble() <= sparsity} + val values = indices.map { rndElement(_) } + val features = Vectors.sparse(weights.length, indices.toArray, values.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with SparseVector - LabeledPoint(p._1, Vectors.dense(p._2).toSparse) + LabeledPoint(label, features) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 60886bf77d2f0..954d3bedc14bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -65,15 +65,15 @@ class RegressionEvaluatorSuite // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1013829 absTol 0.01) // r2 score evaluator.setMetricName("r2") - assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.9998387 absTol 0.01) // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2bdc0e184d734..2f3e703f4c252 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -42,6 +42,7 @@ class LinearRegressionSuite In `LinearRegressionSuite`, we will make sure that the model trained by SparkML is the same as the one trained by R's glmnet package. The following instruction describes how to reproduce the data in R. + In a spark-shell, use the following code: import org.apache.spark.mllib.util.LinearDataGenerator val data = @@ -184,15 +185,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 + as.numeric.data.V2. 6.973403 + as.numeric.data.V3. 5.284370 */ - val coefficientsR = Vectors.dense(6.995908, 5.275131) + val coefficientsR = Vectors.dense(6.973403, 5.284370) - assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR relTol 1E-3) - assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR relTol 1E-3) + assert(model1.intercept ~== 0 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) + assert(model2.intercept ~== 0 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) /* Then again with the data with no intercept: @@ -235,14 +236,14 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 */ - val interceptR1 = 6.24300 - val coefficientsR1 = Vectors.dense(4.024821, 6.679841) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -296,14 +297,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.299752, 4.772913) + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -312,14 +313,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.232193 - as.numeric.data.V3. 4.764229 + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.232193, 4.764229) + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -347,15 +348,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.269376 - as.numeric.data.V2. 3.736216 - as.numeric.data.V3. 5.712356) + (Intercept) 5.260103 + as.numeric.d1.V2. 3.725522 + as.numeric.d1.V3. 5.711203 */ - val interceptR1 = 5.269376 - val coefficientsR1 = Vectors.dense(3.736216, 5.712356) + val interceptR1 = 5.260103 + val coefficientsR1 = Vectors.dense(3.725522, 5.711203) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -363,15 +364,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.791109 - as.numeric.data.V2. 3.435466 - as.numeric.data.V3. 5.910406 + (Intercept) 5.790885 + as.numeric.d1.V2. 3.432373 + as.numeric.d1.V3. 5.919196 */ - val interceptR2 = 5.791109 - val coefficientsR2 = Vectors.dense(3.435466, 5.910406) + val interceptR2 = 5.790885 + val coefficientsR2 = Vectors.dense(3.432373, 5.919196) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -398,15 +399,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 4.214502 + (Intercept) . + as.numeric.d1.V2. 5.493430 + as.numeric.d1.V3. 4.223082 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.522875, 4.214502) + val coefficientsR1 = Vectors.dense(5.493430, 4.223082) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -415,14 +416,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.263704 - as.numeric.data.V3. 4.187419 + as.numeric.d1.V2. 5.244324 + as.numeric.d1.V3. 4.203106 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.263704, 4.187419) + val coefficientsR2 = Vectors.dense(5.244324, 4.203106) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -457,15 +458,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 5.200403 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 */ - val interceptR1 = 5.696056 - val coefficientsR1 = Vectors.dense(3.670489, 6.001122) + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 @@ -473,15 +474,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.114723 - as.numeric.data.V2. 3.409937 - as.numeric.data.V3. 6.146531 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 */ - val interceptR2 = 6.114723 - val coefficientsR2 = Vectors.dense(3.409937, 6.146531) + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -518,15 +519,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.673348, 4.322251) + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, @@ -535,14 +536,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.477988 - as.numeric.data.V3. 4.297622 + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.477988, 4.297622) + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -592,21 +594,47 @@ class LinearRegressionSuite } /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 + # Use the following R code to generate model training results. + + # path/part-00000 is the file generated by running LinearDataGenerator.generateLinearInput + # as described before the beforeAll() method. + d1 <- read.csv("path/part-00000", header=FALSE, stringsAsFactors=FALSE) + fit <- glm(V1 ~ V2 + V3, data = d1, family = "gaussian") + names(f1)[1] = c("V2") + names(f1)[2] = c("V3") + f1 <- data.frame(as.numeric(d1$V2), as.numeric(d1$V3)) + predictions <- predict(fit, newdata=f1) + l1 <- as.numeric(d1$V1) + + residuals <- l1 - predictions + > mean(residuals^2) # MSE + [1] 0.00985449 + > mean(abs(residuals)) # MAD + [1] 0.07961668 + > cor(predictions, l1)^2 # r^2 + [1] 0.9998737 + + > summary(fit) + + Call: + glm(formula = V1 ~ V2 + V3, family = "gaussian", data = d1) + + Deviance Residuals: + Min 1Q Median 3Q Max + -0.47082 -0.06797 0.00002 0.06725 0.34635 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** + V2 4.6982442 0.0011805 3980 <2e-16 *** + V3 7.1994344 0.0009044 7961 <2e-16 *** + --- + + .... */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) + assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) + assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". This algorithm does not generate // objective history because it does not run through iterations. @@ -621,14 +649,14 @@ class LinearRegressionSuite // To clalify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) - val devianceResidualsR = Array(-0.35566, 0.34504) - val seCoefR = Array(0.0011756, 0.0009032, 0.0018489) - val tValsR = Array(3998, 7971, 3407) + val devianceResidualsR = Array(-0.47082, 0.34635) + val seCoefR = Array(0.0011805, 0.0009044, 0.0018600) + val tValsR = Array(3980, 7961, 3388) val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } From 3dc7ca1f3d189df3173059aa42195efe9aa96a75 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 8 Dec 2015 19:18:59 +0800 Subject: [PATCH 1069/1168] [SPARK-11551][DOC][EXAMPLE] Revert PR #10002 This reverts PR #10002, commit 78209b0ccaf3f22b5e2345dfb2b98edfdb746819. The original PR wasn't tested on Jenkins before being merged. Author: Cheng Lian Closes #10200 from liancheng/revert-pr-10002. --- docs/ml-features.md | 1109 ++++++++++++++++- .../examples/ml/JavaBinarizerExample.java | 68 - .../examples/ml/JavaBucketizerExample.java | 70 -- .../spark/examples/ml/JavaDCTExample.java | 65 - .../ml/JavaElementwiseProductExample.java | 75 -- .../examples/ml/JavaMinMaxScalerExample.java | 50 - .../spark/examples/ml/JavaNGramExample.java | 71 -- .../examples/ml/JavaNormalizerExample.java | 52 - .../examples/ml/JavaOneHotEncoderExample.java | 77 -- .../spark/examples/ml/JavaPCAExample.java | 71 -- .../ml/JavaPolynomialExpansionExample.java | 71 -- .../examples/ml/JavaRFormulaExample.java | 69 - .../ml/JavaStandardScalerExample.java | 53 - .../ml/JavaStopWordsRemoverExample.java | 65 - .../examples/ml/JavaStringIndexerExample.java | 66 - .../examples/ml/JavaTokenizerExample.java | 75 -- .../ml/JavaVectorAssemblerExample.java | 67 - .../examples/ml/JavaVectorIndexerExample.java | 60 - .../examples/ml/JavaVectorSlicerExample.java | 73 -- .../src/main/python/ml/binarizer_example.py | 43 - .../src/main/python/ml/bucketizer_example.py | 42 - .../python/ml/elementwise_product_example.py | 39 - examples/src/main/python/ml/n_gram_example.py | 42 - .../src/main/python/ml/normalizer_example.py | 41 - .../main/python/ml/onehot_encoder_example.py | 47 - examples/src/main/python/ml/pca_example.py | 42 - .../python/ml/polynomial_expansion_example.py | 43 - .../src/main/python/ml/rformula_example.py | 44 - .../main/python/ml/standard_scaler_example.py | 42 - .../python/ml/stopwords_remover_example.py | 40 - .../main/python/ml/string_indexer_example.py | 39 - .../src/main/python/ml/tokenizer_example.py | 44 - .../python/ml/vector_assembler_example.py | 42 - .../main/python/ml/vector_indexer_example.py | 39 - .../spark/examples/ml/BinarizerExample.scala | 48 - .../spark/examples/ml/BucketizerExample.scala | 51 - .../apache/spark/examples/ml/DCTExample.scala | 54 - .../ml/ElementWiseProductExample.scala | 53 - .../examples/ml/MinMaxScalerExample.scala | 49 - .../spark/examples/ml/NGramExample.scala | 47 - .../spark/examples/ml/NormalizerExample.scala | 50 - .../examples/ml/OneHotEncoderExample.scala | 58 - .../apache/spark/examples/ml/PCAExample.scala | 54 - .../ml/PolynomialExpansionExample.scala | 53 - .../spark/examples/ml/RFormulaExample.scala | 49 - .../examples/ml/StandardScalerExample.scala | 51 - .../examples/ml/StopWordsRemoverExample.scala | 48 - .../examples/ml/StringIndexerExample.scala | 49 - .../spark/examples/ml/TokenizerExample.scala | 54 - .../examples/ml/VectorAssemblerExample.scala | 49 - .../examples/ml/VectorIndexerExample.scala | 53 - .../examples/ml/VectorSlicerExample.scala | 58 - 52 files changed, 1058 insertions(+), 2806 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java delete mode 100644 examples/src/main/python/ml/binarizer_example.py delete mode 100644 examples/src/main/python/ml/bucketizer_example.py delete mode 100644 examples/src/main/python/ml/elementwise_product_example.py delete mode 100644 examples/src/main/python/ml/n_gram_example.py delete mode 100644 examples/src/main/python/ml/normalizer_example.py delete mode 100644 examples/src/main/python/ml/onehot_encoder_example.py delete mode 100644 examples/src/main/python/ml/pca_example.py delete mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py delete mode 100644 examples/src/main/python/ml/rformula_example.py delete mode 100644 examples/src/main/python/ml/standard_scaler_example.py delete mode 100644 examples/src/main/python/ml/stopwords_remover_example.py delete mode 100644 examples/src/main/python/ml/string_indexer_example.py delete mode 100644 examples/src/main/python/ml/tokenizer_example.py delete mode 100644 examples/src/main/python/ml/vector_assembler_example.py delete mode 100644 examples/src/main/python/ml/vector_indexer_example.py delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index f85e0d56d2e40..01d6abeb5ba6a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,7 +170,25 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} + +val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +)).toDF("label", "sentence") +val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) +{% endhighlight %}
    @@ -179,7 +197,44 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); +Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); +DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); +for (Row r : wordsDataFrame.select("words", "label").take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); +} + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); +{% endhighlight %}
    @@ -188,7 +243,21 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% include_example python/ml/tokenizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Tokenizer, RegexTokenizer + +sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +], ["label", "sentence"]) +tokenizer = Tokenizer(inputCol="sentence", outputCol="words") +wordsDataFrame = tokenizer.transform(sentenceDataFrame) +for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) +{% endhighlight %}
    @@ -237,7 +306,19 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %}
    @@ -245,7 +326,34 @@ for more details on the API. Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %}
    @@ -253,7 +361,17 @@ for more details on the API. Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example python/ml/stopwords_remover_example.py %} +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %}
    @@ -270,7 +388,19 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %}
    @@ -278,7 +408,38 @@ for more details on the API. Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %}
    @@ -286,7 +447,19 @@ for more details on the API. Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% include_example python/ml/n_gram_example.py %} +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) +{% endhighlight %}
    @@ -303,7 +476,26 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Binarizer +import org.apache.spark.sql.DataFrame + +val data = Array( + (0, 0.1), + (1, 0.8), + (2, 0.2) +) +val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + +val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + +val binarizedDataFrame = binarizer.transform(dataFrame) +val binarizedFeatures = binarizedDataFrame.select("binarized_feature") +binarizedFeatures.collect().foreach(println) +{% endhighlight %}
    @@ -311,7 +503,40 @@ for more details on the API. Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); +Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); +DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); +DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); +for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); +} +{% endhighlight %}
    @@ -319,7 +544,20 @@ for more details on the API. Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% include_example python/ml/binarizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Binarizer + +continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) +], ["label", "feature"]) +binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") +binarizedDataFrame = binarizer.transform(continuousDataFrame) +binarizedFeatures = binarizedDataFrame.select("binarized_feature") +for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) +{% endhighlight %}
    @@ -333,7 +571,25 @@ for more details on the API. Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %}
    @@ -341,7 +597,42 @@ for more details on the API. Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA +import org.apache.spark.ml.feature.PCAModel +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %}
    @@ -349,7 +640,19 @@ for more details on the API. Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% include_example python/ml/pca_example.py %} +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %}
    @@ -363,7 +666,23 @@ for more details on the API. Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) +val polyDF = polynomialExpansion.transform(df) +polyDF.select("polyFeatures").take(3).foreach(println) +{% endhighlight %}
    @@ -371,7 +690,43 @@ for more details on the API. Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DataFrame polyDF = polyExpansion.transform(df); +Row[] row = polyDF.select("polyFeatures").take(3); +for (Row r : row) { + System.out.println(r.get(0)); +} +{% endhighlight %}
    @@ -379,7 +734,20 @@ for more details on the API. Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example python/ml/polynomial_expansion_example.py %} +{% highlight python %} +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors + +df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) +px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") +polyDF = px.transform(df) +for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) +{% endhighlight %}
    @@ -403,7 +771,22 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors + +val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) +val dctDf = dct.transform(df) +dctDf.select("featuresDCT").show(3) +{% endhighlight %}
    @@ -411,7 +794,39 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); +DataFrame dctDf = dct.transform(df); +dctDf.select("featuresDCT").show(3); +{% endhighlight %}
    @@ -466,7 +881,18 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
    @@ -474,7 +900,37 @@ for more details on the API. Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %}
    @@ -482,7 +938,16 @@ for more details on the API. Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% include_example python/ml/string_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
    @@ -496,7 +961,29 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} + +val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +)).toDF("id", "category") + +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) +val indexed = indexer.transform(df) + +val encoder = new OneHotEncoder().setInputCol("categoryIndex"). + setOutputCol("categoryVec") +val encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").foreach(println) +{% endhighlight %}
    @@ -504,7 +991,45 @@ for more details on the API. Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); +DataFrame indexed = indexer.transform(df); + +OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); +DataFrame encoded = encoder.transform(indexed); +{% endhighlight %}
    @@ -512,7 +1037,24 @@ for more details on the API. Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% highlight python %} +from pyspark.ml.feature import OneHotEncoder, StringIndexer + +df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +], ["id", "category"]) + +stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +model = stringIndexer.fit(df) +indexed = model.transform(df) +encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") +encoded = encoder.transform(indexed) +{% endhighlight %}
    @@ -536,7 +1078,23 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.VectorIndexer + +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) +val indexerModel = indexer.fit(data) +val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet +println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + +// Create new column "indexed" with categorical values transformed to indices +val indexedData = indexerModel.transform(data) +{% endhighlight %}
    @@ -544,7 +1102,30 @@ for more details on the API. Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %} +{% highlight java %} +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); +VectorIndexerModel indexerModel = indexer.fit(data); +Map> categoryMaps = indexerModel.javaCategoryMaps(); +System.out.print("Chose " + categoryMaps.size() + "categorical features:"); +for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); +} +System.out.println(); + +// Create new column "indexed" with categorical values transformed to indices +DataFrame indexedData = indexerModel.transform(data); +{% endhighlight %}
    @@ -552,7 +1133,17 @@ for more details on the API. Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% include_example python/ml/vector_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import VectorIndexer + +data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) +indexerModel = indexer.fit(data) + +# Create new column "indexed" with categorical values transformed to indices +indexedData = indexerModel.transform(data) +{% endhighlight %}
    @@ -569,7 +1160,22 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Normalizer + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +// Normalize each Vector using $L^1$ norm. +val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) +val l1NormData = normalizer.transform(dataFrame) + +// Normalize each Vector using $L^\infty$ norm. +val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) +{% endhighlight %}
    @@ -577,7 +1183,24 @@ for more details on the API. Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + +// Normalize each Vector using $L^1$ norm. +Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); +DataFrame l1NormData = normalizer.transform(dataFrame); + +// Normalize each Vector using $L^\infty$ norm. +DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); +{% endhighlight %}
    @@ -585,7 +1208,19 @@ for more details on the API. Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% include_example python/ml/normalizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Normalizer + +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +# Normalize each Vector using $L^1$ norm. +normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) +l1NormData = normalizer.transform(dataFrame) + +# Normalize each Vector using $L^\infty$ norm. +lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) +{% endhighlight %}
    @@ -609,7 +1244,23 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StandardScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + +// Compute summary statistics by fitting the StandardScaler +val scalerModel = scaler.fit(dataFrame) + +// Normalize each feature to have unit standard deviation. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -617,7 +1268,25 @@ for more details on the API. Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + +// Compute summary statistics by fitting the StandardScaler +StandardScalerModel scalerModel = scaler.fit(dataFrame); + +// Normalize each feature to have unit standard deviation. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
    @@ -625,7 +1294,20 @@ for more details on the API. Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% include_example python/ml/standard_scaler_example.py %} +{% highlight python %} +from pyspark.ml.feature import StandardScaler + +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + +# Compute summary statistics by fitting the StandardScaler +scalerModel = scaler.fit(dataFrame) + +# Normalize each feature to have unit standard deviation. +scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -655,7 +1337,21 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -664,7 +1360,24 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %} +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
    @@ -688,7 +1401,23 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Bucketizer +import org.apache.spark.sql.DataFrame + +val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + +val data = Array(-0.5, -0.3, 0.0, 0.2) +val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + +val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + +// Transform original data into its bucket index. +val bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
    @@ -696,7 +1425,38 @@ for more details on the API. Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame dataFrame = jsql.createDataFrame(data, schema); + +Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + +// Transform original data into its bucket index. +DataFrame bucketedData = bucketizer.transform(dataFrame); +{% endhighlight %}
    @@ -704,7 +1464,19 @@ for more details on the API. Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% include_example python/ml/bucketizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Bucketizer + +splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + +data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] +dataFrame = sqlContext.createDataFrame(data, ["features"]) + +bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + +# Transform original data into its bucket index. +bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
    @@ -736,7 +1508,25 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +transformer.transform(dataFrame).show() + +{% endhighlight %}
    @@ -744,7 +1534,41 @@ for more details on the API. Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// Create some vector data; also works for sparse vectors +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) +)); +List fields = new ArrayList(2); +fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); +fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); +StructType schema = DataTypes.createStructType(fields); +DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); +// Batch transform the vectors to create new column: +transformer.transform(dataFrame).show(); + +{% endhighlight %}
    @@ -752,8 +1576,19 @@ for more details on the API. Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example python/ml/elementwise_product_example.py %} +{% highlight python %} +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] +df = sqlContext.createDataFrame(data, ["vector"]) +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") +transformer.transform(df).show() + +{% endhighlight %}
    + ## SQLTransformer @@ -856,7 +1691,19 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %} +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %}
    @@ -864,7 +1711,36 @@ for more details on the API. Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) +}); +Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); +JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + +DataFrame output = assembler.transform(dataset); +System.out.println(output.select("features", "clicked").first()); +{% endhighlight %}
    @@ -872,7 +1748,19 @@ for more details on the API. Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% include_example python/ml/vector_assembler_example.py %} +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %}
    @@ -1002,7 +1890,33 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %} +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0) +) + +val defaultAttr = NumericAttribute.defaultAttr +val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) +val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + +val dataRDD = sc.parallelize(data).map(Row.apply) +val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) + +val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + +slicer.setIndices(1).setNames("f3") +// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + +val output = slicer.transform(dataset) +println(output.select("userFeatures", "features").first()) +{% endhighlight %}
    @@ -1010,7 +1924,41 @@ for more details on the API. Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") +}; +AttributeGroup group = new AttributeGroup("userFeatures", attrs); + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) +)); + +DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + +VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + +vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); +// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + +DataFrame output = vectorSlicer.transform(dataset); + +System.out.println(output.select("userFeatures", "features").first()); +{% endhighlight %}
    @@ -1047,7 +1995,21 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
    @@ -1055,7 +2017,38 @@ for more details on the API. Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) +}); +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) +)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + +DataFrame output = formula.fit(dataset).transform(dataset); +output.select("features", "label").show(); +{% endhighlight %}
    @@ -1063,7 +2056,21 @@ for more details on the API. Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% include_example python/ml/rformula_example.py %} +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java deleted file mode 100644 index 9698cac504371..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBinarizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); - Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); - DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); - DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); - for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java deleted file mode 100644 index b06a23e76d604..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Bucketizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBucketizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame dataFrame = jsql.createDataFrame(data, schema); - - Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - - // Transform original data into its bucket index. - DataFrame bucketedData = bucketizer.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java deleted file mode 100644 index 35c0d534a45e9..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaDCTExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - DataFrame df = jsql.createDataFrame(data, schema); - DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); - DataFrame dctDf = dct.transform(df); - dctDf.select("featuresDCT").show(3); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java deleted file mode 100644 index 2898accec61b0..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaElementwiseProductExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - // Create some vector data; also works for sparse vectors - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) - )); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); - - StructType schema = DataTypes.createStructType(fields); - - DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); - - Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); - - ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java deleted file mode 100644 index 138b3ab6aba44..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaMinMaxScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - - // Compute summary statistics and generate MinMaxScalerModel - MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - - // rescale each feature to range [min, max]. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java deleted file mode 100644 index 8fd75ed8b5f4e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaNGramExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField( - "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); - - NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); - - DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); - - for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java deleted file mode 100644 index 6283a355e1fef..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaNormalizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - // Normalize each Vector using $L^1$ norm. - Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); - - DataFrame l1NormData = normalizer.transform(dataFrame); - - // Normalize each Vector using $L^\infty$ norm. - DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java deleted file mode 100644 index 172a9cc6feb28..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaOneHotEncoderExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) - }); - - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - DataFrame indexed = indexer.transform(df); - - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - DataFrame encoded = encoder.transform(indexed); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java deleted file mode 100644 index 8282fab084f36..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PCA; -import org.apache.spark.ml.feature.PCAModel; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPCAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - - PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); - - DataFrame result = pca.transform(df).select("pcaFeatures"); - result.show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java deleted file mode 100644 index 668f71e64056b..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PolynomialExpansion; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPolynomialExpansionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - DataFrame polyDF = polyExpansion.transform(df); - - Row[] row = polyDF.select("polyFeatures").take(3); - for (Row r : row) { - System.out.println(r.get(0)); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java deleted file mode 100644 index 1e1062b541ad9..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaRFormulaExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - StructType schema = createStructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) - }); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) - )); - - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - DataFrame output = formula.fit(dataset).transform(dataset); - output.select("features", "label").show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java deleted file mode 100644 index 0cbdc97e8ae30..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaStandardScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - - // Compute summary statistics by fitting the StandardScaler - StandardScalerModel scalerModel = scaler.fit(dataFrame); - - // Normalize each feature to have unit standard deviation. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java deleted file mode 100644 index b6b201c6b68d2..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaStopWordsRemoverExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField( - "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame dataset = jsql.createDataFrame(rdd, schema); - remover.transform(dataset).show(); - // $example off$ - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java deleted file mode 100644 index 05d12c1e702f1..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaStringIndexerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - StructType schema = new StructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("category", StringType, false) - }); - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); - DataFrame indexed = indexer.fit(df).transform(df); - indexed.show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java deleted file mode 100644 index 617dc3f66e3bf..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaTokenizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) - }); - - DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); - - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); - for (Row r : wordsDataFrame.select("words", "label"). take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); - } - - RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); - // $example off$ - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java deleted file mode 100644 index 7e230b5897c1e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaVectorAssemblerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - StructType schema = createStructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) - }); - Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); - JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - - VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - - DataFrame output = assembler.transform(dataset); - System.out.println(output.select("features", "clicked").first()); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java deleted file mode 100644 index 06b4bf6bf8ff6..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaVectorIndexerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); - VectorIndexerModel indexerModel = indexer.fit(data); - - Map> categoryMaps = indexerModel.javaCategoryMaps(); - System.out.print("Chose " + categoryMaps.size() + " categorical features:"); - - for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); - } - System.out.println(); - - // Create new column "indexed" with categorical values transformed to indices - DataFrame indexedData = indexerModel.transform(data); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java deleted file mode 100644 index 4d5cb04ff5e2b..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.attribute.Attribute; -import org.apache.spark.ml.attribute.AttributeGroup; -import org.apache.spark.ml.attribute.NumericAttribute; -import org.apache.spark.ml.feature.VectorSlicer; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -// $example off$ - -public class JavaVectorSlicerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") - }; - AttributeGroup group = new AttributeGroup("userFeatures", attrs); - - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) - )); - - DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - - VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - - vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - - DataFrame output = vectorSlicer.transform(dataset); - - System.out.println(output.select("userFeatures", "features").first()); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py deleted file mode 100644 index 960ad208be12e..0000000000000 --- a/examples/src/main/python/ml/binarizer_example.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Binarizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BinarizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) - ], ["label", "feature"]) - binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") - binarizedDataFrame = binarizer.transform(continuousDataFrame) - binarizedFeatures = binarizedDataFrame.select("binarized_feature") - for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py deleted file mode 100644 index a12750aa9248a..0000000000000 --- a/examples/src/main/python/ml/bucketizer_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Bucketizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BucketizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - - data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] - dataFrame = sqlContext.createDataFrame(data, ["features"]) - - bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - - # Transform original data into its bucket index. - bucketedData = bucketizer.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py deleted file mode 100644 index c85cb0d89543c..0000000000000 --- a/examples/src/main/python/ml/elementwise_product_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="ElementwiseProductExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] - df = sqlContext.createDataFrame(data, ["vector"]) - transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") - transformer.transform(df).show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py deleted file mode 100644 index f2d85f53e7219..0000000000000 --- a/examples/src/main/python/ml/n_gram_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import NGram -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NGramExample") - sqlContext = SQLContext(sc) - - # $example on$ - wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) - ], ["label", "words"]) - ngram = NGram(inputCol="words", outputCol="ngrams") - ngramDataFrame = ngram.transform(wordDataFrame) - for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py deleted file mode 100644 index 833d93e976a7e..0000000000000 --- a/examples/src/main/python/ml/normalizer_example.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Normalizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NormalizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Normalize each Vector using $L^1$ norm. - normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) - l1NormData = normalizer.transform(dataFrame) - - # Normalize each Vector using $L^\infty$ norm. - lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py deleted file mode 100644 index 7529dfd09213a..0000000000000 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="OneHotEncoderExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) - - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py deleted file mode 100644 index 8b66140a40a7a..0000000000000 --- a/examples/src/main/python/ml/pca_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PCAExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] - df = sqlContext.createDataFrame(data,["features"]) - pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") - model = pca.fit(df) - result = model.transform(df).select("pcaFeatures") - result.show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py deleted file mode 100644 index 030a6132a451a..0000000000000 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PolynomialExpansionExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) - px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") - polyDF = px.transform(df) - for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py deleted file mode 100644 index b544a14700762..0000000000000 --- a/examples/src/main/python/ml/rformula_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import RFormula -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="RFormulaExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) - formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") - output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py deleted file mode 100644 index 139acecbfb53f..0000000000000 --- a/examples/src/main/python/ml/standard_scaler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StandardScaler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StandardScalerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - - # Compute summary statistics by fitting the StandardScaler - scalerModel = scaler.fit(dataFrame) - - # Normalize each feature to have unit standard deviation. - scaledData = scalerModel.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py deleted file mode 100644 index 01f94af8ca752..0000000000000 --- a/examples/src/main/python/ml/stopwords_remover_example.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StopWordsRemover -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StopWordsRemoverExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) - ], ["label", "raw"]) - - remover = StopWordsRemover(inputCol="raw", outputCol="filtered") - remover.transform(sentenceData).show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py deleted file mode 100644 index 58a8cb5d56b73..0000000000000 --- a/examples/src/main/python/ml/string_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StringIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) - indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - indexed = indexer.fit(df).transform(df) - indexed.show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py deleted file mode 100644 index ce9b225be5357..0000000000000 --- a/examples/src/main/python/ml/tokenizer_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Tokenizer, RegexTokenizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="TokenizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - ], ["label", "sentence"]) - tokenizer = Tokenizer(inputCol="sentence", outputCol="words") - wordsDataFrame = tokenizer.transform(sentenceDataFrame) - for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) - regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") - # alternatively, pattern="\\w+", gaps(False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py deleted file mode 100644 index 04f64839f188d..0000000000000 --- a/examples/src/main/python/ml/vector_assembler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorAssemblerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) - assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") - output = assembler.transform(dataset) - print(output.select("features", "clicked").first()) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py deleted file mode 100644 index cc00d1454f2e0..0000000000000 --- a/examples/src/main/python/ml/vector_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import VectorIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) - indexerModel = indexer.fit(data) - - # Create new column "indexed" with categorical values transformed to indices - indexedData = indexerModel.transform(data) - # $example off$ - - sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala deleted file mode 100644 index e724aa587294b..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Binarizer -// $example off$ -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} - -object BinarizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BinarizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - // $example on$ - val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) - val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - - val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - - val binarizedDataFrame = binarizer.transform(dataFrame) - val binarizedFeatures = binarizedDataFrame.select("binarized_feature") - binarizedFeatures.collect().foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala deleted file mode 100644 index 30c2776d39688..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Bucketizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object BucketizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BucketizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - - val data = Array(-0.5, -0.3, 0.0, 0.2) - val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - - // Transform original data into its bucket index. - val bucketedData = bucketizer.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala deleted file mode 100644 index 314c2c28a2a10..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object DCTExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DCTExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) - - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) - - val dctDf = dct.transform(df) - dctDf.select("featuresDCT").show(3) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala deleted file mode 100644 index ac50bb7b2b155..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object ElementwiseProductExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ElementwiseProductExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Create some vector data; also works for sparse vectors - val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - - val transformingVector = Vectors.dense(0.0, 1.0, 2.0) - val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala deleted file mode 100644 index dac3679a5bf7e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object MinMaxScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("MinMaxScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - - // Compute summary statistics and generate MinMaxScalerModel - val scalerModel = scaler.fit(dataFrame) - - // rescale each feature to range [min, max]. - val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala deleted file mode 100644 index 8a85f71b56f3d..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.NGram -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NGramExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NGramExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) - )).toDF("label", "words") - - val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") - val ngramDataFrame = ngram.transform(wordDataFrame) - ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala deleted file mode 100644 index 17571f0aad793..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Normalizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NormalizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NormalizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - // Normalize each Vector using $L^1$ norm. - val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) - - val l1NormData = normalizer.transform(dataFrame) - - // Normalize each Vector using $L^\infty$ norm. - val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala deleted file mode 100644 index 4512736943dd5..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object OneHotEncoderExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("OneHotEncoderExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") - val encoded = encoder.transform(indexed) - encoded.select("id", "categoryVec").foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala deleted file mode 100644 index a18d4f33973d8..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PCAExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PCAExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) - val pcaDF = pca.transform(df) - val result = pcaDF.select("pcaFeatures") - result.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala deleted file mode 100644 index b8e9e6952a5ea..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PolynomialExpansionExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PolynomialExpansionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) - val polyDF = polynomialExpansion.transform(df) - polyDF.select("polyFeatures").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala deleted file mode 100644 index 286866edea502..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.RFormula -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object RFormulaExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RFormulaExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) - )).toDF("id", "country", "hour", "clicked") - val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") - val output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala deleted file mode 100644 index 646ce0f13ecf5..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StandardScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StandardScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - - // Compute summary statistics by fitting the StandardScaler. - val scalerModel = scaler.fit(dataFrame) - - // Normalize each feature to have unit standard deviation. - val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala deleted file mode 100644 index 655ffce08d3ab..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StopWordsRemover -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StopWordsRemoverExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StopWordsRemoverExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") - - val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) - )).toDF("id", "raw") - - remover.transform(dataSet).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala deleted file mode 100644 index 1be8a5f33f7c0..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StringIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StringIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StringIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) - ).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - - val indexed = indexer.fit(df).transform(df) - indexed.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala deleted file mode 100644 index 01e0d1388a2f4..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object TokenizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("TokenizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - )).toDF("label", "sentence") - - val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") - val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - - val tokenized = tokenizer.transform(sentenceDataFrame) - tokenized.select("words", "label").take(3).foreach(println) - val regexTokenized = regexTokenizer.transform(sentenceDataFrame) - regexTokenized.select("words", "label").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala deleted file mode 100644 index d527924419f81..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorAssemblerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorAssemblerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) - ).toDF("id", "hour", "mobile", "userFeatures", "clicked") - - val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") - - val output = assembler.transform(dataset) - println(output.select("features", "clicked").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala deleted file mode 100644 index 14279d610fda8..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) - - val indexerModel = indexer.fit(data) - - val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - - // Create new column "indexed" with categorical values transformed to indices - val indexedData = indexerModel.transform(data) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala deleted file mode 100644 index 04f19829eff87..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.Row -import org.apache.spark.sql.types.StructType -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorSlicerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorSlicerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) - - val defaultAttr = NumericAttribute.defaultAttr - val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) - val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - - val dataRDD = sc.parallelize(data) - val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) - - val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - - slicer.setIndices(Array(1)).setNames(Array("f3")) - // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - - val output = slicer.transform(dataset) - println(output.select("userFeatures", "features").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println From f7fc52a71eb44e42fc358fff41fe247a6b292509 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 8 Dec 2015 14:34:47 +0000 Subject: [PATCH 1070/1168] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Fix commons-collection group ID to commons-collections for version 3.x Patches earlier PR at https://github.com/apache/spark/pull/9731 Author: Sean Owen Closes #10198 from srowen/SPARK-11652.2. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ae2ff8878b0a5..5daca03f61436 100644 --- a/pom.xml +++ b/pom.xml @@ -478,7 +478,7 @@ ${commons.math3.version}
    - org.apache.commons + commons-collections commons-collections ${commons.collections.version} From 1ee1b4fbda05f0019c09b1af49a0ba84631cdaa8 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Tue, 8 Dec 2015 11:44:51 -0600 Subject: [PATCH 1071/1168] [SPARK-11155][WEB UI] Stage summary json should include stage duration The json endpoint for stages doesn't include information on the stage duration that is present in the UI. This looks like a simple oversight, they should be included. eg., the metrics should be included at api/v1/applications//stages. Metrics I've added are: submissionTime, firstTaskLaunchedTime and completionTime Author: Xin Ren Closes #10107 from keypointt/SPARK-11155. --- .../status/api/v1/AllStagesResource.scala | 14 ++++- .../org/apache/spark/status/api/v1/api.scala | 3 + .../complete_stage_list_json_expectation.json | 11 +++- .../failed_stage_list_json_expectation.json | 5 +- .../one_stage_attempt_json_expectation.json | 5 +- .../one_stage_json_expectation.json | 5 +- .../stage_list_json_expectation.json | 14 ++++- ...ist_with_accumulable_json_expectation.json | 5 +- ...age_with_accumulable_json_expectation.json | 5 +- .../api/v1/AllStagesResourceSuite.scala | 62 +++++++++++++++++++ project/MimaExcludes.scala | 4 +- 11 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 24a0b5220695c..31b4dd7c0f427 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.{GET, PathParam, Produces, QueryParam} import javax.ws.rs.core.MediaType +import javax.ws.rs.{GET, Produces, QueryParam} import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} @@ -59,6 +59,15 @@ private[v1] object AllStagesResource { stageUiData: StageUIData, includeDetails: Boolean): StageData = { + val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + + val firstTaskLaunchedTime: Option[Date] = + if (taskLaunchTimes.nonEmpty) { + Some(new Date(taskLaunchTimes.min)) + } else { + None + } + val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) } else { @@ -92,6 +101,9 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + submissionTime = stageInfo.submissionTime.map(new Date(_)), + firstTaskLaunchedTime, + completionTime = stageInfo.completionTime.map(new Date(_)), inputBytes = stageUiData.inputBytes, inputRecords = stageUiData.inputRecords, outputBytes = stageUiData.outputBytes, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index baddfc50c1a40..5feb1dc2e5b74 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -120,6 +120,9 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val submissionTime: Option[Date], + val firstTaskLaunchedTime: Option[Date], + val completionTime: Option[Date], val inputBytes: Long, val inputRecords: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 31ac9beea8788..8f8067f86d57f 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -64,4 +73,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index bff6a4f69d077..08b692eda8028 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -20,4 +23,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 111cb8163eb3d..b07011d4f113f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index ef339f89afa45..2f71520549e1f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 056fac7088594..5b957ed549556 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -72,6 +81,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -86,4 +98,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 79ccacd309693..afa425f8c27bb 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -24,4 +27,4 @@ "name" : "my counter", "value" : "5050" } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 32d5731676ad5..12665a152c9ec 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -239,4 +242,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala new file mode 100644 index 0000000000000..88817dccf3497 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} + +class AllStagesResourceSuite extends SparkFunSuite { + + def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { + val tasks = new HashMap[Long, TaskUIData] + taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => + tasks(idx.toLong) = new TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + } + + val stageUiData = new StageUIData() + stageUiData.taskData = tasks + val status = StageStatus.ACTIVE + val stageInfo = new StageInfo( + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) + + stageData.firstTaskLaunchedTime + } + + test("firstTaskLaunchedTime when there are no tasks") { + val result = getFirstTaskLaunchTime(Seq()) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks but none launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks and some launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) + assert(result == Some(new Date(1449255596000L))) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4aa6adc3c620..685cb419ca8a7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -132,7 +132,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") ) ++ Seq ( ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this") + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") ) ++ Seq( // SPARK-11766 add toJson to Vector ProblemFilters.exclude[MissingMethodProblem]( From 303e6f26b87c1c8eee4a186d469ceaef4a92613c Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 8 Dec 2015 10:01:44 -0800 Subject: [PATCH 1072/1168] [SPARK-12074] Avoid memory copy involving ByteBuffer.wrap(ByteArrayOutputStream.toByteArray) SPARK-12060 fixed JavaSerializerInstance.serialize This PR applies the same technique on two other classes. zsxwing Author: tedyu Closes #10177 from tedyu/master. --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 7 +++---- .../main/scala/org/apache/spark/storage/BlockManager.scala | 4 ++-- .../org/apache/spark/util/ByteBufferOutputStream.scala | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5fe5ae8c45819..d4bc3a5c900f7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -27,8 +27,7 @@ import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -172,7 +171,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new ByteArrayOutputStream(4096) + val out = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -193,7 +192,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) - ByteBuffer.wrap(out.toByteArray) + out.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ab0007fb78993..ed05143877e20 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1202,9 +1202,9 @@ private[spark] class BlockManager( blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) + val byteStream = new ByteBufferOutputStream(4096) dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) + byteStream.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 92e45224db81c..8527e3ae692e2 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -23,7 +23,9 @@ import java.nio.ByteBuffer /** * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { +private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { + + def this() = this(32) def toByteBuffer: ByteBuffer = { return ByteBuffer.wrap(buf, 0, count) From a551f53e14cd093b307ee1843678b85473ead02e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Dec 2015 10:13:40 -0800 Subject: [PATCH 1073/1168] [SPARK-12201][SQL] add type coercion rule for greatest/least checked with hive, greatest/least should cast their children to a tightest common type, i.e. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error` Author: Wenchen Fan Closes #10196 from cloud-fan/type-coercion. --- .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++++++++++ .../ExpressionTypeCheckingSuite.scala | 10 ++++++++ .../analysis/HiveTypeCoercionSuite.scala | 23 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 29502a59915f0..dbcbd6854b474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -594,6 +594,20 @@ object HiveTypeCoercion { case None => c } + case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ba1866efc84e1..915c585ec91fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, + 'decimalField.decimal(8, 0), 'arrayField.array(StringType), 'mapField.map(StringType, LongType)) @@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d3fafaae89938..142915056f451 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), From e0accf0b211830f1b8899e70707cd15072b3bf79 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:15:58 -0800 Subject: [PATCH 1074/1168] [SPARK-12195][SQL] Adding BigDecimal, Date and Timestamp into Encoder This PR is to add three more data types into Encoder, including `BigDecimal`, `Date` and `Timestamp`. marmbrus cloud-fan rxin Could you take a quick look at these three types? Not sure if it can be merged to 1.6. Thank you very much! Author: gatorsmile Closes #10188 from gatorsmile/dataTypesinEncoder. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index c40061ae0aafd..3ca5ade7f30f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * An encoder for nullable decimal type. + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ae47f4fe0e231..383a2d0badb53 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -18,6 +18,9 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; import java.util.*; import scala.Tuple2; @@ -385,6 +388,20 @@ public void testNestedTupleEncoder() { Assert.assertEquals(data3, ds3.collectAsList()); } + @Test + public void testPrimitiveEncoder() { + Encoder> encoder = + Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), + Encoders.FLOAT()); + List> data = + Arrays.asList(new Tuple5( + 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), + Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); + Dataset> ds = + context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + @Test public void testTypedAggregation() { Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); From 4426d9507faf466786a903f7a884f6dcf2f8cc2b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:25:57 -0800 Subject: [PATCH 1075/1168] [SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs This PR contains the following updates: - Created a new private variable `boundTEncoder` that can be shared by multiple functions, `RDD`, `select` and `collect`. - Replaced all the `queryExecution.analyzed` by the function call `logicalPlan` - A few API comments are using wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`) - A few API descriptions are wrong. (e.g., `mapPartitions`) marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you! Author: gatorsmile Closes #10184 from gatorsmile/datasetClean. --- .../scala/org/apache/spark/sql/Dataset.scala | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d6bb1d2ad8e50..3bd18a14f9e8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -67,15 +67,21 @@ class Dataset[T] private[sql]( tEncoder: Encoder[T]) extends Queryable with Serializable { /** - * An unresolved version of the internal encoder for the type of this dataset. This one is marked - * implicit so that we can use it when constructing new [[Dataset]] objects that have the same - * object type (that will be possibly resolved to a different schema). + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of the given schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) private implicit def classTag = resolvedTEncoder.clsTag @@ -89,7 +95,7 @@ class Dataset[T] private[sql]( override def schema: StructType = resolvedTEncoder.schema /** - * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format. * @since 1.6.0 */ override def printSchema(): Unit = toDF().printSchema() @@ -111,7 +117,7 @@ class Dataset[T] private[sql]( * ************* */ /** - * Returns a new `Dataset` where each record has been mapped on to the specified type. The + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The * method used to map columns depend on the type of `U`: * - When `U` is a class, fields for the class will be mapped to columns of the same name * (case sensitivity is determined by `spark.sql.caseSensitive`) @@ -145,7 +151,7 @@ class Dataset[T] private[sql]( def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) /** - * Returns this Dataset. + * Returns this [[Dataset]]. * @since 1.6.0 */ // This is declared with parentheses to prevent the Scala compiler from treating @@ -153,15 +159,12 @@ class Dataset[T] private[sql]( def toDS(): Dataset[T] = this /** - * Converts this Dataset to an RDD. + * Converts this [[Dataset]] to an [[RDD]]. * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => - val bound = tEnc.bind(input) - iter.map(bound.fromRow) + iter.map(boundTEncoder.fromRow) } } @@ -189,7 +192,7 @@ class Dataset[T] private[sql]( def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters * will be truncated, and all cells will be aligned right. * * @since 1.6.0 @@ -197,7 +200,7 @@ class Dataset[T] private[sql]( def show(): Unit = show(20) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[Dataset]] in a tabular form. * * @param truncate Whether truncate long strings. If true, strings more than 20 characters will * be truncated and all cells will be aligned right @@ -207,7 +210,7 @@ class Dataset[T] private[sql]( def show(truncate: Boolean): Unit = show(20, truncate) /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[Dataset]] in a tabular form. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -291,7 +294,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { @@ -307,7 +310,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { @@ -341,28 +344,28 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** * (Java-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) /** * (Java-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = @@ -374,7 +377,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -382,7 +385,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -390,11 +393,11 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { - val inputPlan = queryExecution.analyzed + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) @@ -429,18 +432,18 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ - def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = - groupBy(f.call(_))(encoder) + def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(func.call(_))(encoder) /* ****************** * * Typed Relational * * ****************** */ /** - * Selects a set of column based expressions. + * Returns a new [[DataFrame]] by selecting a set of column based expressions. * {{{ * df.select($"colA", $"colB" + 1) * }}} @@ -464,8 +467,8 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder.bind(queryExecution.analyzed.output), - queryExecution.analyzed.output).named :: Nil, + boundTEncoder, + logicalPlan.output).named :: Nil, logicalPlan)) } @@ -477,7 +480,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) + columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) @@ -654,7 +657,7 @@ class Dataset[T] private[sql]( * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -662,17 +665,14 @@ class Dataset[T] private[sql]( def collect(): Array[T] = { // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders // to convert the rows into objects of type T. - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output - val bound = tEnc.bind(input) - queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow) + queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) } /** * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -683,7 +683,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -692,7 +692,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) From 526862cf160ed103559f18e00a4de97c1dcf87c3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 10:29:51 -0800 Subject: [PATCH 1076/1168] [SPARK-10393] use ML pipeline in LDA example jira: https://issues.apache.org/jira/browse/SPARK-10393 Since the logic of the text processing part has been moved to ML estimators/transformers, replace the related code in LDA Example with the ML pipeline. Author: Yuhao Yang Author: yuhaoyang Closes #8551 from hhbyyh/ldaExUpdate. --- .../spark/examples/mllib/LDAExample.scala | 153 +++++------------- 1 file changed, 40 insertions(+), 113 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 75b0f69cf91aa..70010b05e4345 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,19 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import java.text.BreakIterator - -import scala.collection.mutable - import scopt.OptionParser import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with @@ -192,115 +189,45 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + // Get dataset of document texts // One document per line in each text file. If the input consists of many small files, // this can result in a large number of small partitions, which can degrade performance. // In this case, consider using coalesce() to create fewer, larger partitions. - val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) - - // Split text into words - val tokenizer = new SimpleTokenizer(sc, stopwordFile) - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => - id -> tokenizer.getWords(text) - } - tokenized.cache() - - // Counts words: RDD[(word, wordCount)] - val wordCounts: RDD[(String, Long)] = tokenized - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - wordCounts.cache() - val fullVocabSize = wordCounts.count() - // Select vocab - // (vocab: Map[word -> id], total tokens after selecting vocab) - val (vocab: Map[String, Int], selectedTokenCount: Long) = { - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) - } - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) - } - - val documents = tokenized.map { case (id, tokens) => - // Filter tokens by vocabulary, and create word count vector representation of document. - val wc = new mutable.HashMap[Int, Int]() - tokens.foreach { term => - if (vocab.contains(term)) { - val termIndex = vocab(term) - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 - } - } - val indices = wc.keys.toArray.sorted - val values = indices.map(i => wc(i).toDouble) - - val sb = Vectors.sparse(vocab.size, indices, values) - (id, sb) - } - - val vocabArray = new Array[String](vocab.size) - vocab.foreach { case (term, i) => vocabArray(i) = term } - - (documents, vocabArray, selectedTokenCount) - } -} - -/** - * Simple Tokenizer. - * - * TODO: Formalize the interface, and make this a public class in mllib.feature - */ -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { - - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { - Set.empty[String] - } else { - val stopwordText = sc.textFile(stopwordFile).collect() - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet - } - - // Matches sequences of Unicode letters - private val allWordRegex = "^(\\p{L}*)$".r - - // Ignore words shorter than this length. - private val minWordLength = 3 - - def getWords(text: String): IndexedSeq[String] = { - - val words = new mutable.ArrayBuffer[String]() - - // Use Java BreakIterator to tokenize text into words. - val wb = BreakIterator.getWordInstance - wb.setText(text) - - // current,end index start,end of each word - var current = wb.first() - var end = wb.next() - while (end != BreakIterator.DONE) { - // Convert to lowercase - val word: String = text.substring(current, end).toLowerCase - // Remove short words and strings that aren't only letters - word match { - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => - words += w - case _ => - } - - current = end - try { - end = wb.next() - } catch { - case e: Exception => - // Ignore remaining text in line. - // This is a known bug in BreakIterator (for some Java versions), - // which fails when it sees certain characters. - end = BreakIterator.DONE - } + val df = sc.textFile(paths.mkString(",")).toDF("docs") + val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) { + Array.empty[String] + } else { + val stopWordText = sc.textFile(stopwordFile).collect() + stopWordText.flatMap(_.stripMargin.split("\\s+")) } - words + val tokenizer = new RegexTokenizer() + .setInputCol("docs") + .setOutputCol("rawTokens") + val stopWordsRemover = new StopWordsRemover() + .setInputCol("rawTokens") + .setOutputCol("tokens") + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(vocabSize) + .setInputCol("tokens") + .setOutputCol("features") + + val pipeline = new Pipeline() + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) + + val model = pipeline.fit(df) + val documents = model.transform(df) + .select("features") + .map { case Row(features: Vector) => features } + .zipWithIndex() + .map(_.swap) + + (documents, + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary + documents.map(_._2.numActives).sum().toLong) // total token count } - } // scalastyle:on println From bf0176bb4f490560d0f106699c9e8596e319217f Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Tue, 8 Dec 2015 10:52:17 -0800 Subject: [PATCH 1077/1168] [SPARK-12205][SQL] Pivot fails Analysis when aggregate is UnresolvedFunction Delays application of ResolvePivot until all aggregates are resolved to prevent problems with UnresolvedFunction and adds unit test Author: Andrew Ray Closes #10202 from aray/sql-pivot-unresolved-function. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../scala/org/apache/spark/sql/DataFramePivotSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d3163dcd4db94..ca00a5e49f668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -259,7 +259,7 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved => p + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index fc53aba68ebb7..bc1a336ea4fd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -85,4 +85,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } + + test("pivot with UnresolvedFunction") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg("earnings" -> "sum"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } } From b82de2061f8c6e2c5ee79bf05da5b62eddec94a3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 11:46:26 -0800 Subject: [PATCH 1078/1168] [SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs jira: https://issues.apache.org/jira/browse/SPARK-11605 Check Java compatibility for MLlib for this release. fix: 1. `StreamingTest.registerStream` needs java friendly interface. 2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` has java compatibility issue. Mark them as `developerAPI`. TBD: [updated] no fix for now per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel` `public scala.Option getThreshold();` has wrong return type for Java invocation. `SVMModel` has the similar issue. Yet adding a `scala.Option getThreshold()` would result in an overloading error due to the same function signature. And adding a new function with different name seems to be not necessary. cc jkbradley feynmanliang Author: Yuhao Yang Closes #10102 from hhbyyh/javaAPI. --- .../examples/mllib/StreamingTestExample.scala | 4 +- .../spark/mllib/stat/test/StreamingTest.scala | 50 +++++++++++++++---- .../mllib/tree/model/treeEnsembleModels.scala | 6 ++- .../spark/mllib/stat/JavaStatisticsSuite.java | 38 ++++++++++++-- .../spark/mllib/stat/StreamingTestSuite.scala | 25 +++++----- 5 files changed, 96 insertions(+), 27 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index b6677c6476639..49f5df39443e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf -import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.util.Utils @@ -66,7 +66,7 @@ object StreamingTestExample { // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { - case Array(label, value) => (label.toBoolean, value.toDouble) + case Array(label, value) => BinarySample(label.toBoolean, value.toDouble) }) val streamingTest = new StreamingTest() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 75c6a51d09571..e990fe0768bc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,12 +17,30 @@ package org.apache.spark.mllib.stat.test +import scala.beans.BeanInfo + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.JavaDStream import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter +/** + * Class that represents the group and value of a sample. + * + * @param isExperiment if the sample is of the experiment group. + * @param value numeric value of the observation. + */ +@Since("1.6.0") +@BeanInfo +case class BinarySample @Since("1.6.0") ( + @Since("1.6.0") isExperiment: Boolean, + @Since("1.6.0") value: Double) { + override def toString: String = { + s"($isExperiment, $value)" + } +} + /** * :: Experimental :: * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The @@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** * Register a [[DStream]] of values for significance testing. * - * @param data stream of (key,value) pairs where the key denotes group membership (true = - * experiment, false = control) and the value is the numerical metric to test for - * significance + * @param data stream of BinarySample(key,value) pairs where the key denotes group membership + * (true = experiment, false = control) and the value is the numerical metric to + * test for significance * @return stream of significance testing results */ @Since("1.6.0") - def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = { val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) @@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { testMethod.doTest(pairedSummaries) } + /** + * Register a [[JavaDStream]] of values for significance testing. + * + * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes + * group (true = experiment, false = control) and the value is the numerical metric + * to test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = { + JavaDStream.fromDStream(registerStream(data.dstream)) + } + /** Drop all batches inside the peace period. */ private[stat] def dropPeacePeriod( - data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data: DStream[BinarySample]): DStream[BinarySample] = { data.transform { (rdd, time) => if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { rdd @@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** Compute summary statistics over each key and the specified test window size. */ private[stat] def summarizeByKeyAndWindow( - data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = { + val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value)) if (this.windowSize == 0) { - data.updateStateByKey[StatCounter]( + categoryValuePair.updateStateByKey[StatCounter]( (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { val newSummary = oldSummary.getOrElse(new StatCounter()) newSummary.merge(newValues) @@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { }) } else { val windowDuration = data.slideDuration * this.windowSize - data + categoryValuePair .groupByKeyAndWindow(windowDuration) .mapValues { values => val summary = new StatCounter() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 3f427f0be3af2..feabcee24fa2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Since +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** + * :: DeveloperApi :: * Compute the initial predictions and errors for a dataset for the first * iteration of gradient boosting. * @param data: training data. @@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to every sample. */ @Since("1.4.0") + @DeveloperApi def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { } /** + * :: DeveloperApi :: * Update a zipped predictionError RDD * (as obtained with computeInitialPredictionAndError) * @param data: training data. @@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to each sample. */ @Since("1.4.0") + @DeveloperApi def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 4795809e47a46..66b2ceacb05f2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -18,34 +18,49 @@ package org.apache.spark.mllib.stat; import java.io.Serializable; - import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("JavaStatistics") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + sc = new JavaSparkContext(conf); + ssc = new JavaStreamingContext(sc, new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After public void tearDown() { - sc.stop(); + ssc.stop(); + ssc = null; sc = null; } @@ -76,4 +91,21 @@ public void chiSqTest() { new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); ChiSqTestResult[] testResults = Statistics.chiSqTest(data); } + + @Test + public void streamingTest() { + List trainingBatch = Arrays.asList( + new BinarySample(true, 1.0), + new BinarySample(false, 2.0)); + JavaDStream training = + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + int numBatches = 2; + StreamingTest model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod("welch"); + model.registerStream(training); + attachTestOutputStream(training); + runStreams(ssc, numBatches, numBatches); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index d3e9ef4ff079c..3c657c8cfe743 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, + WelchTTest, BinarySample} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter @@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) @@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, - (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream)) val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) val outputCounts = outputBatches.flatten.map(_._2.count) @@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream)) val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) @@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { .setPeacePeriod(0) val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) - .map(batch => batch.filter(_._1)) // only keep one test group + .map(batch => batch.filter(_.isExperiment)) // only keep one test group // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) @@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { stdevA: Double, meanB: Double, stdevB: Double, - seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = { val rand = new XORShiftRandom(seed) val numTrues = pointsPerBatch / 2 val data = (0 until numBatches).map { i => - (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++ (pointsPerBatch / 2 until pointsPerBatch).map { idx => - (false, meanB + stdevB * rand.nextGaussian()) + BinarySample(false, meanB + stdevB * rand.nextGaussian()) } } From add793d35347ac9b3783aa43a1d3ad93a6066d39 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Tue, 8 Dec 2015 12:45:34 -0800 Subject: [PATCH 1079/1168] [SPARK-12159][ML] Add user guide section for IndexToString transformer Documentation regarding the `IndexToString` label transformer with code snippets in Scala/Java/Python. Author: BenFradet Closes #10166 from BenFradet/SPARK-12159. --- docs/ml-features.md | 104 +++++++++++++++--- .../examples/ml/JavaIndexToStringExample.java | 75 +++++++++++++ .../main/python/ml/index_to_string_example.py | 45 ++++++++ .../examples/ml/IndexToStringExample.scala | 60 ++++++++++ 4 files changed, 268 insertions(+), 16 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java create mode 100644 examples/src/main/python/ml/index_to_string_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 01d6abeb5ba6a..e15c26836affc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -835,10 +835,10 @@ dctDf.select("featuresDCT").show(3); `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string -values. When downstream pipeline components such as `Estimator` or -`Transformer` make use of this string-indexed label, you must set the input -column of the component to this string-indexed column name. In many cases, +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, you can set the input column with `setInputCol`. **Examples** @@ -951,9 +951,78 @@ indexed.show() + +## IndexToString + +Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices +back to a column containing the original labels as strings. The common use case +is to produce indices from labels with `StringIndexer`, train a model with those +indices and retrieve the original labels from the column of predicted indices +with `IndexToString`. However, you are free to supply your own labels. + +**Examples** + +Building on the `StringIndexer` example, let's assume we have the following +DataFrame with columns `id` and `categoryIndex`: + +~~~~ + id | categoryIndex +----|--------------- + 0 | 0.0 + 1 | 2.0 + 2 | 1.0 + 3 | 0.0 + 4 | 0.0 + 5 | 1.0 +~~~~ + +Applying `IndexToString` with `categoryIndex` as the input column, +`originalCategory` as the output column, we are able to retrieve our original +labels (they will be inferred from the columns' metadata): + +~~~~ + id | categoryIndex | originalCategory +----|---------------|----------------- + 0 | 0.0 | a + 1 | 2.0 | b + 2 | 1.0 | c + 3 | 0.0 | a + 4 | 0.0 | a + 5 | 1.0 | c +~~~~ + +
    +
    + +Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %} + +
    + +
    + +Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %} + +
    + +
    + +Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString) +for more details on the API. + +{% include_example python/ml/index_to_string_example.py %} + +
    +
    + ## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
    @@ -979,10 +1048,11 @@ val indexer = new StringIndexer() .fit(df) val indexed = indexer.transform(df) -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") +val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) +encoded.select("id", "categoryVec").show() {% endhighlight %}
    @@ -1015,7 +1085,7 @@ JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(5, "c") )); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("category", DataTypes.StringType, false, Metadata.empty()) }); DataFrame df = sqlContext.createDataFrame(jrdd, schema); @@ -1029,6 +1099,7 @@ OneHotEncoder encoder = new OneHotEncoder() .setInputCol("categoryIndex") .setOutputCol("categoryVec"); DataFrame encoded = encoder.transform(indexed); +encoded.select("id", "categoryVec").show(); {% endhighlight %}
    @@ -1054,6 +1125,7 @@ model = stringIndexer.fit(df) indexed = model.transform(df) encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").show() {% endhighlight %} @@ -1582,7 +1654,7 @@ from pyspark.mllib.linalg import Vectors data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), inputCol="vector", outputCol="transformedVector") transformer.transform(df).show() @@ -1837,15 +1909,15 @@ for more details on the API. sub-array of the original features. It is useful for extracting features from a vector column. `VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column -whose values are selected via those indices. There are two types of indices, +whose values are selected via those indices. There are two types of indices, 1. Integer indices that represents the indices into the vector, `setIndices()`; - 2. String indices that represents the names of features into the vector, `setNames()`. + 2. String indices that represents the names of features into the vector, `setNames()`. *This requires the vector column to have an `AttributeGroup` since the implementation matches on the name field of an `Attribute`.* -Specification by integer and string are both acceptable. Moreover, you can use integer index and +Specification by integer and string are both acceptable. Moreover, you can use integer index and string name simultaneously. At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names. Note that if names of features are selected, an exception will be threw out when encountering with empty input attributes. @@ -1858,9 +1930,9 @@ followed by the selected names (in the order given). Suppose that we have a DataFrame with the column `userFeatures`: ~~~ - userFeatures + userFeatures ------------------ - [0.0, 10.0, 0.5] + [0.0, 10.0, 0.5] ~~~ `userFeatures` is a vector column that contains three user features. Assuming that the first column @@ -1874,7 +1946,7 @@ column named `features`: [0.0, 10.0, 0.5] | [10.0, 0.5] ~~~ -Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. `["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. ~~~ diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java new file mode 100644 index 0000000000000..3ccd6993261e2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.IndexToString; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaIndexToStringExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + IndexToString converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory"); + DataFrame converted = converter.transform(indexed); + converted.select("id", "originalCategory").show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py new file mode 100644 index 0000000000000..fb0ba2950bbd6 --- /dev/null +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import IndexToString, StringIndexer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="IndexToStringExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + + converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory") + converted = converter.transform(indexed) + + converted.select("id", "originalCategory").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala new file mode 100644 index 0000000000000..52537e5bb568d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +// $example off$ + +object IndexToStringExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("IndexToStringExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory") + + val converted = converter.transform(indexed) + converted.select("id", "originalCategory").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 3fd7988e2b6baa16f854a75aff3d214b1d11356c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 Dec 2015 13:13:56 -0800 Subject: [PATCH 1080/1168] [SPARK-3873][BUILD] Add style checker to enforce import ordering. The checker tries to follow as closely as possible the guidelines of the code style document, and makes some decisions where the guide is not clear. In particular: - wildcard imports come first when there are other imports in the same package - multi-import blocks come before single imports - lower-case names inside multi-import blocks come before others In some projects, such as graphx, there seems to be a convention to separate o.a.s imports from the project's own; to simplify the checker, I chose not to allow that, which is a strict interpretation of the code style guide, even though I think it makes sense. Since the checks are based on syntax only, some edge cases may generate spurious warnings; for example, when class names start with a lower case letter (and are thus treated as a package name by the checker). The checker is currently only generating warnings, and since there are many of those, the build output does get a little noisy. The idea is to fix the code (and the checker, as needed) little by little instead of having a huge change that touches everywhere. Author: Marcelo Vanzin Closes #6502 from vanzin/SPARK-3873. --- scalastyle-config.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 050c3f360476f..dab1ebddc666e 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -153,7 +153,7 @@ This file is divided into 3 sections: @VisibleForTesting @@ -203,6 +203,18 @@ This file is divided into 3 sections: + + + + java,scala,3rdParty,spark + javax?\..+ + scala\..+ + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + From 1bcc18b2ecf046f149a270eae69a94e97a129d68 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 8 Dec 2015 14:34:15 -0800 Subject: [PATCH 1081/1168] [SPARK-12187] *MemoryPool classes should not be fully public This patch tightens them to `private[memory]`. Author: Andrew Or Closes #10182 from andrewor14/memory-visibility. --- .../scala/org/apache/spark/memory/ExecutionMemoryPool.scala | 2 +- core/src/main/scala/org/apache/spark/memory/MemoryPool.scala | 2 +- .../main/scala/org/apache/spark/memory/StorageMemoryPool.scala | 2 +- .../scala/org/apache/spark/memory/UnifiedMemoryManager.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 7825bae425877..9023e1ac012b7 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -39,7 +39,7 @@ import org.apache.spark.Logging * @param lock a [[MemoryManager]] instance to synchronize on * @param poolName a human-readable name for this pool, for use in log messages */ -class ExecutionMemoryPool( +private[memory] class ExecutionMemoryPool( lock: Object, poolName: String ) extends MemoryPool(lock) with Logging { diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala index bfeec47e3892e..1b9edf9c43bda 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -27,7 +27,7 @@ import javax.annotation.concurrent.GuardedBy * to `Object` to avoid programming errors, since this object should only be used for * synchronization purposes. */ -abstract class MemoryPool(lock: Object) { +private[memory] abstract class MemoryPool(lock: Object) { @GuardedBy("lock") private[this] var _poolSize: Long = 0 diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 6a322eabf81ed..fc4f0357e9f16 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} * * @param lock a [[MemoryManager]] instance to synchronize on */ -class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { +private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { @GuardedBy("lock") private[this] var _memoryUsed: Long = 0L diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 48b4e23433e43..0f1ea9ab39c07 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -49,7 +49,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, val maxMemory: Long, - private val storageRegionSize: Long, + storageRegionSize: Long, numCores: Int) extends MemoryManager( conf, From a384868b51a1ff131a70d33d0907f3c0477e372c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Dec 2015 15:58:35 -0800 Subject: [PATCH 1082/1168] [SPARK-12069][SQL] Update documentation with Datasets Author: Michael Armbrust Closes #10060 from marmbrus/docs. --- docs/_layouts/global.html | 2 +- docs/index.md | 2 +- docs/sql-programming-guide.md | 268 +++++++++++------- .../scala/org/apache/spark/sql/Encoder.scala | 48 +++- .../scala/org/apache/spark/sql/Column.scala | 21 +- 5 files changed, 237 insertions(+), 104 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 1b09e2221e173..0b5b0cd48af64 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -71,7 +71,7 @@
  • Spark Programming Guide
  • Spark Streaming
  • -
  • DataFrames and SQL
  • +
  • DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • diff --git a/docs/index.md b/docs/index.md index f1d9e012c6cf0..ae26f97c86c21 100644 --- a/docs/index.md +++ b/docs/index.md @@ -87,7 +87,7 @@ options for deployment: in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7b1d97baa3823..9f87accd30f40 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Spark SQL and DataFrame Guide +displayTitle: Spark SQL, DataFrames and Datasets Guide title: Spark SQL and DataFrames --- @@ -9,18 +9,51 @@ title: Spark SQL and DataFrames # Overview -Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided +by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, +Spark SQL uses this extra information to perform extra optimizations. There are several ways to +interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +the same execution engine is used, independent of which API/language you are using to express the +computation. This unification means that developers can easily switch back and forth between the +various APIs based on which provides the most natural way to express a given transformation. -Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. +All of the examples on this page use sample data included in the Spark distribution and can be run in +the `spark-shell`, `pyspark` shell, or `sparkR` shell. -# DataFrames +## SQL -A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. +One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to +configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running +SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) +or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). +## DataFrames -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R/Python, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such +as: structured data files, tables in Hive, external databases, or existing RDDs. +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), +[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), +[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +## Datasets + +A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of +RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's +optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated +using functional transformations (map, flatMap, filter, etc.). + +The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and +[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for +the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can +access the field of a row by name naturally `row.columnName`). Full python support will be added +in a future release. + +# Getting Started ## Starting Point: SQLContext @@ -29,7 +62,7 @@ All of the examples on this page use sample data included in the Spark distribut The entry point into all functionality in Spark SQL is the [`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} val sc: SparkContext // An existing SparkContext. @@ -45,7 +78,7 @@ import sqlContext.implicits._ The entry point into all functionality in Spark SQL is the [`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. @@ -58,7 +91,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); The entry point into all relational functionality in Spark is the [`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext @@ -70,7 +103,7 @@ sqlContext = SQLContext(sc)
    The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight r %} sqlContext <- sparkRSQL.init(sc) @@ -82,18 +115,18 @@ sqlContext <- sparkRSQL.init(sc) In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a superset of the functionality provided by the basic `SQLContext`. Additional features include the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an existing Hive setup, and all of the data sources available to a `SQLContext` are still available. `HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up to feature parity with a `HiveContext`. The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, this is recommended for most use cases. @@ -215,7 +248,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
    @@ -270,7 +303,7 @@ df.groupBy("age").count().show(); For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). @@ -331,7 +364,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). @@ -385,7 +418,7 @@ showDF(count(groupBy(df, "age"))) For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). @@ -398,14 +431,14 @@ The `sql` function on a `SQLContext` enables applications to run SQL queries pro
    {% highlight scala %} -val sqlContext = ... // An existing SQLContext +val sqlContext = ... // An existing SQLContext val df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    {% highlight java %} -SQLContext sqlContext = ... // An existing SQLContext +SQLContext sqlContext = ... // An existing SQLContext DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    @@ -428,15 +461,54 @@ df <- sql(sqlContext, "SELECT * FROM table")
    +## Creating Datasets + +Datasets are similar to RDDs, however, instead of using Java Serialization or Kryo they use +a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects +for processing or transmitting over the network. While both encoders and standard serialization are +responsible for turning an object into bytes, encoders are code generated dynamically and use a format +that allows Spark to perform many operations like filtering, sorting and hashing without deserializing +the bytes back into an object. + +
    +
    + +{% highlight scala %} +// Encoders for most common types are automatically provided by importing sqlContext.implicits._ +val ds = Seq(1, 2, 3).toDS() +ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) + +// Encoders are also created for case classes. +case class Person(name: String, age: Long) +val ds = Seq(Person("Andy", 32)).toDS() + +// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. +val path = "examples/src/main/resources/people.json" +val people = sqlContext.read.json(path).as[Person] + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} + +
    +
    + ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first -method uses reflection to infer the schema of an RDD that contains specific types of objects. This +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating DataFrames is through a programmatic interface that allows you to -construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection @@ -445,11 +517,11 @@ you to construct DataFrames when the columns and their types are not known until
    The Scala interface for Spark SQL supports automatically converting an RDD containing case classes -to a DataFrame. The case class -defines the schema of the table. The names of the arguments to the case class are read using +to a DataFrame. The case class +defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be -registered as a table. Tables can be used in subsequent SQL statements. +registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. @@ -486,9 +558,9 @@ teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(printl
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. {% highlight java %} @@ -559,9 +631,9 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
    -Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first +and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we plan to more completely infer the schema by looking at more data, similar to the inference that is performed on JSON files. @@ -780,7 +852,7 @@ for name in names.collect(): Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Registering a DataFrame as a table allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -834,9 +906,9 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") ### Manually Specifying Options You can also manually specify the data source that will be used along with any extra options -that you would like to pass to the data source. Data sources are specified by their fully qualified +that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
    @@ -923,8 +995,8 @@ df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users. ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if -present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. @@ -960,7 +1032,7 @@ new data.
    Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
    @@ -968,14 +1040,14 @@ new data. ### Saving to Persistent Tables When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SQLContext` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically +be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. ## Parquet Files @@ -1003,7 +1075,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the previous // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.write.parquet("people.parquet") -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.read.parquet("people.parquet") @@ -1025,7 +1097,7 @@ DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write().parquet("people.parquet"); -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); @@ -1051,7 +1123,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write.parquet("people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.read.parquet("people.parquet") @@ -1075,7 +1147,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. saveAsParquetFile(schemaPeople, "people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile <- parquetFile(sqlContext, "people.parquet") @@ -1110,10 +1182,10 @@ SELECT * FROM parquetTable ### Partition Discovery -Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in -the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For example, we can store all our previously used +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1155,7 +1227,7 @@ root {% endhighlight %} -Notice that the data types of the partitioning columns are automatically inferred. Currently, +Notice that the data types of the partitioning columns are automatically inferred. Currently, numeric data types and string type are supported. Sometimes users may not want to automatically infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to @@ -1164,13 +1236,13 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w ### Schema Merging -Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with -a simple schema, and gradually add more columns to the schema as needed. In this way, users may end -up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we -turned it off by default starting from 1.5.0. You may enable it by +turned it off by default starting from 1.5.0. You may enable it by 1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the examples below), or @@ -1284,10 +1356,10 @@ processing. 1. Hive considers all columns nullable, while nullability in Parquet is significant Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a -Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: 1. Fields that have the same name in both schema must have the same data type regardless of - nullability. The reconciled field should have the data type of the Parquet side, so that + nullability. The reconciled field should have the data type of the Parquet side, so that nullability is respected. 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. @@ -1298,8 +1370,8 @@ Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation r #### Metadata Refreshing -Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table -conversion is enabled, metadata of those converted tables are also cached. If these tables are +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are updated by Hive or other external tools, you need to refresh them manually to ensure consistent metadata. @@ -1362,7 +1434,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1400,7 +1472,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

    The output committer class used by Parquet. The specified class needs to be a subclass of - org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a subclass of org.apache.parquet.hadoop.ParquetOutputCommitter.

    @@ -1628,7 +1700,7 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the hive-site.xml, the context automatically creates `metastore_db` in the current directory and creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts @@ -1738,10 +1810,10 @@ The following options can be used to configure the version of Hive that is used enabled. When this option is chosen, spark.sql.hive.metastore.version must be either 1.2.1 or not defined.

  • maven
  • - Use Hive jars of specified version downloaded from Maven repositories. This configuration + Use Hive jars of specified version downloaded from Maven repositories. This configuration is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive - and its dependencies, including the correct version of Hadoop. These jars only need to be +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
  • @@ -1776,7 +1848,7 @@ The following options can be used to configure the version of Hive that is used ## JDBC To Other Databases -Spark SQL also includes a data source that can read data from other databases using JDBC. This +Spark SQL also includes a data source that can read data from other databases using JDBC. This functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). This is because the results are returned as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. @@ -1786,7 +1858,7 @@ provide a ClassTag. run queries using Spark SQL). To get started you will need to include the JDBC driver for you particular database on the -spark classpath. For example, to connect to postgres from the Spark Shell you would run the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: {% highlight bash %} @@ -1794,7 +1866,7 @@ SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. The following options are supported: @@ -1807,8 +1879,8 @@ the Data Sources API. The following options are supported: @@ -1816,7 +1888,7 @@ the Data Sources API. The following options are supported: @@ -1825,7 +1897,7 @@ the Data Sources API. The following options are supported: @@ -1947,7 +2019,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ ## Other Configuration Options -The following options can also be used to tune the performance of query execution. It is possible +The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically.
    Property NameMeaning
    dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a + The JDBC table that should be read. Note that anything that is valid in a FROM clause of + a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.
    driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded + The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running an JDBC commands to allow the driver to register itself with the JDBC subsystem.
    partitionColumn, lowerBound, upperBound, numPartitions - These options must all be specified if any of them is specified. They describe how to + These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the @@ -1938,7 +2010,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ spark.sql.inMemoryColumnarStorage.batchSize 10000 - Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data.
    @@ -1957,7 +2029,7 @@ that these options will be deprecated in future release as more optimizations ar @@ -1995,8 +2067,8 @@ To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to -specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of -all available options. By default, the server listens on localhost:10000. You may override this +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this behaviour via either environment variables, i.e.: {% highlight bash %} @@ -2062,10 +2134,10 @@ options. ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC - connection owns a copy of their own SQL configuration and temporary function registry. Cached - tables are still shared though. If you prefer to run the Thrift server in the old single-session - mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: {% highlight bash %} @@ -2077,20 +2149,20 @@ options. ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with - code generation for expression evaluation. These features can both be disabled by setting + code generation for expression evaluation. These features can both be disabled by setting `spark.sql.tungsten.enabled` to `false`. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if - your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum - precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. - Timestamps are now stored at a precision of 1us, rather than 1ns - - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe @@ -2183,38 +2255,38 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false") ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the -available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other -releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked as unstable (i.e., DeveloperAPI or Experimental). #### Rename of SchemaRDD to DataFrame The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has -been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD directly, but instead provide most of the functionality that RDDs provide though their own -implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for -some use cases. It is still recommended that users update their code to use `DataFrame` instead. +some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. #### Unification of the Java and Scala APIs Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) -that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. #### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought -all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. Users should now write `import sqlContext.implicits._`. @@ -2222,7 +2294,7 @@ Additionally, the implicit conversions now only augment RDDs that are composed o case classes or tuples) with a method `toDF`, instead of applying automatically. When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import -`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: `import org.apache.spark.sql.functions._`. #### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 3ca5ade7f30f1..bb0fdc4c3d83b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -19,20 +19,60 @@ package org.apache.spark.sql import java.lang.reflect.Modifier +import scala.annotation.implicitNotFound import scala.reflect.{ClassTag, classTag} +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** + * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * - * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking - * and reuse internal buffers to improve performance. + * == Scala == + * Encoders are generally created automatically through implicits from a `SQLContext`. + * + * {{{ + * import sqlContext.implicits._ + * + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * }}} + * + * == Java == + * Encoders are specified by calling static methods on [[Encoders]]. + * + * {{{ + * List data = Arrays.asList("abc", "abc", "xyz"); + * Dataset ds = context.createDataset(data, Encoders.STRING()); + * }}} + * + * Encoders can be composed into tuples: + * + * {{{ + * Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + * List> data2 = Arrays.asList(new scala.Tuple2(1, "a"); + * Dataset> ds2 = context.createDataset(data2, encoder2); + * }}} + * + * Or constructed from Java Beans: + * + * {{{ + * Encoders.bean(MyClass.class); + * }}} + * + * == Implementation == + * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard + * against concurrent access if they reuse internal buffers to improve performance. * * @since 1.6.0 */ +@Experimental +@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + + "(Int, String, etc) and Product types (case classes) are supported by importing " + + "sqlContext.implicits._ Support for serializing other types will be added in future " + + "releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ @@ -43,10 +83,12 @@ trait Encoder[T] extends Serializable { } /** - * Methods for creating encoders. + * :: Experimental :: + * Methods for creating an [[Encoder]]. * * @since 1.6.0 */ +@Experimental object Encoders { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad6af481fadc4..d641fcac1c8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -73,7 +73,26 @@ class TypedColumn[-T, U]( /** * :: Experimental :: - * A column in a [[DataFrame]]. + * A column that will be computed based on the data in a [[DataFrame]]. + * + * A new column is constructed based on the input columns present in a dataframe: + * + * {{{ + * df("columnName") // On a specific DataFrame. + * col("columnName") // A generic column no yet associcated with a DataFrame. + * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * expr("a + 1") // A column that is constructed from a parsed SQL Expression. + * lit("1") // A column that produces a literal (constant) value. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: + * + * {{{ + * $"a" + 1 + * $"a" === $"b" + * }}} * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators From a4aaed06d1d0917d336d03db90c7df4345850f8a Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 8 Dec 2015 18:40:21 -0800 Subject: [PATCH 1083/1168] [SPARK-8517][ML][DOC] Reorganizes the spark.ml user guide This PR moves pieces of the spark.ml user guide to reflect suggestions in SPARK-8517. It does not introduce new content, as requested. screen shot 2015-12-08 at 11 36 00 am Author: Timothy Hunter Closes #10207 from thunterdb/spark-8517. --- docs/_data/menu-ml.yaml | 18 +- docs/ml-advanced.md | 13 + docs/ml-ann.md | 62 -- docs/ml-classification-regression.md | 775 ++++++++++++++++++++++ docs/ml-clustering.md | 5 + docs/ml-features.md | 4 +- docs/ml-intro.md | 941 +++++++++++++++++++++++++++ docs/mllib-guide.md | 15 +- 8 files changed, 1752 insertions(+), 81 deletions(-) create mode 100644 docs/ml-advanced.md delete mode 100644 docs/ml-ann.md create mode 100644 docs/ml-classification-regression.md create mode 100644 docs/ml-intro.md diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index dff3d33bf4ed1..fe37d0573e46b 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -1,10 +1,10 @@ -- text: Feature extraction, transformation, and selection +- text: "Overview: estimators, transformers and pipelines" + url: ml-intro.html +- text: Extracting, transforming and selecting features url: ml-features.html -- text: Decision trees for classification and regression - url: ml-decision-tree.html -- text: Ensembles - url: ml-ensembles.html -- text: Linear methods with elastic-net regularization - url: ml-linear-methods.html -- text: Multilayer perceptron classifier - url: ml-ann.html +- text: Classification and Regression + url: ml-classification-regression.html +- text: Clustering + url: ml-clustering.html +- text: Advanced topics + url: ml-advanced.html diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md new file mode 100644 index 0000000000000..b005633e56c11 --- /dev/null +++ b/docs/ml-advanced.md @@ -0,0 +1,13 @@ +--- +layout: global +title: Advanced topics - spark.ml +displayTitle: Advanced topics +--- + +# Optimization of linear methods + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. diff --git a/docs/ml-ann.md b/docs/ml-ann.md deleted file mode 100644 index 6e763e8f41568..0000000000000 --- a/docs/ml-ann.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -layout: global -title: Multilayer perceptron classifier - ML -displayTitle: ML - Multilayer perceptron classifier ---- - - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). -MLPC consists of multiple layers of nodes. -Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs -by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. -It can be written in matrix form for MLPC with `$K+1$` layers as follows: -`\[ -\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) -\]` -Nodes in intermediate layers use sigmoid (logistic) function: -`\[ -\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} -\]` -Nodes in the output layer use softmax function: -`\[ -\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} -\]` -The number of nodes `$N$` in the output layer corresponds to the number of classes. - -MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. - -**Examples** - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} -
    - -
    -{% include_example python/ml/multilayer_perceptron_classification.py %} -
    - -
    diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md new file mode 100644 index 0000000000000..3663ffee32753 --- /dev/null +++ b/docs/ml-classification-regression.md @@ -0,0 +1,775 @@ +--- +layout: global +title: Classification and regression - spark.ml +displayTitle: Classification and regression in spark.ml +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + + +# Classification + +## Logistic regression + +Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. +For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + + > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + +**Example** + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} +
    + +
    +{% include_example python/ml/logistic_regression_with_elastic_net.py %} +
    + +
    + +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `Dataframe` in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + + +## Decision tree classifier + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} + +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} + +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% include_example python/ml/decision_tree_classification_example.py %} + +
    + +
    + +## Random forest classifier + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% include_example python/ml/random_forest_classifier_example.py %} +
    +
    + +## Gradient-boosted tree classifier + +Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %} +
    +
    + +## Multilayer perceptron classifier + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    + +
    +{% include_example python/ml/multilayer_perceptron_classification.py %} +
    + +
    + + +## One-vs-Rest classifier (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +**Example** + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. + +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} +
    +
    + + +# Regression + +## Linear regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. + +**Example** + +The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} +
    + +
    + +{% include_example python/ml/linear_regression_with_elastic_net.py %} +
    + +
    + + +## Decision tree regression + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% include_example python/ml/decision_tree_regression_example.py %} +
    + +
    + + +## Random forest regression + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% include_example python/ml/random_forest_regressor_example.py %} +
    +
    + +## Gradient-boosted tree regression + +Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %} +
    +
    + + +## Survival regression + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
    + + + +# Decision trees + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). +The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). + +## Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +### Input Columns + +
    10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when - performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
    + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + + +# Tree Ensembles + +The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [MLlib decision trees](ml-decision-tree.html) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + +* support for ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +## Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +MLlib supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + + + +## Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +MLlib supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +Note that `GBTClassifier` currently only supports binary labels. + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index cfefb5dfbde9e..697777714b05b 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -6,6 +6,11 @@ displayTitle: ML - Clustering In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/docs/ml-features.md b/docs/ml-features.md index e15c26836affc..55e401221917e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction, Transformation, and Selection - SparkML -displayTitle: ML - Features +title: Extracting, transforming and selecting features +displayTitle: Extracting, transforming and selecting features --- This section covers algorithms for working with features, roughly divided into these groups: diff --git a/docs/ml-intro.md b/docs/ml-intro.md new file mode 100644 index 0000000000000..d95a66ba23566 --- /dev/null +++ b/docs/ml-intro.md @@ -0,0 +1,941 @@ +--- +layout: global +title: "Overview: estimators, transformers and pipelines - spark.ml" +displayTitle: "Overview: estimators, transformers and pipelines" +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. + +**Table of contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + + +# Main concepts in Pipelines + +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. + +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. + +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. + +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. + +* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. + +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. + +## DataFrame + +Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. + +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. + +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. + +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." + +## Pipeline components + +### Transformers + +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. +For example: + +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. + +### Estimators + +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. + +### Properties of pipeline components + +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. + +Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). + +## Pipeline + +In machine learning, it is common to run a sequence of algorithms to process and learn from data. +E.g., a simple text document processing workflow might include several stages: + +* Split each document's text into words. +* Convert each document's words into a numerical feature vector. +* Learn a prediction model using the feature vectors and labels. + +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. + +### How it works + +A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. + +We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. + +

    + Spark ML Pipeline Example +

    + +Above, the top row represents a `Pipeline` with three stages. +The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. +Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. + +A `Pipeline` is an `Estimator`. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage. + +

    + Spark ML PipelineModel Example +

    + +In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. +Each stage's `transform()` method updates the dataset and passes it to the next stage. + +`Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. + +### Details + +*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. + +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + +## Parameters + +Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. + +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. + +There are two main ways to pass parameters to an algorithm: + +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. +2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. + +Parameters belong to specific instances of `Estimator`s and `Transformer`s. +For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. +This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. + +# Code examples + +This section gives code examples illustrating the functionality discussed above. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. + +## Example: Estimator, Transformer, and Param + +This example covers the concepts of `Estimator`, `Transformer`, and `Param`. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row + +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") + +// Create a LogisticRegression instance. This instance is an Estimator. +val lr = new LogisticRegression() +// Print out the parameters, documentation, and any default values. +println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01) + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +val model1 = lr.fit(training) +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + +// We may alternatively specify parameters using a ParamMap, +// which supports several methods for specifying parameters. +val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + +// One can also combine ParamMaps. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name +val paramMapCombined = paramMap ++ paramMap2 + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +val model2 = lr.fit(training, paramMapCombined) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + +// Prepare test data. +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") + +// Make predictions on test data using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Prepare training data. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); + +// Create a LogisticRegression instance. This instance is an Estimator. +LogisticRegression lr = new LogisticRegression(); +// Print out the parameters, documentation, and any default values. +System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01); + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +LogisticRegressionModel model1 = lr.fit(training); +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + +// We may alternatively specify parameters using a ParamMap. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + +// One can also combine ParamMaps. +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + +// Prepare test documents. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); + +// Make predictions on test documents using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params + +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training, paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +{% endhighlight %} +
    + +
    + +## Example: Pipeline + +This example follows the simple text document `Pipeline` illustrated in the figures above. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// Fit the pipeline to training documents. +val model = pipeline.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. +model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// Fit the pipeline to training documents. +PipelineModel model = pipeline.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row + +# Prepare training documents from a list of (id, text, label) tuples. +LabeledDocument = Row("id", "text", "label") +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + +# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print(row) + +{% endhighlight %} +
    + +
    + +## Example: model selection via cross-validation + +An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. +`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. + +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). +`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. +`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +method in each of these evaluators. + +The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. +`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +The following example demonstrates using `CrossValidator` to select from a grid of parameters. +To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. + +Note that cross-validation over a grid of parameters is expensive. +E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. +In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). +In other words, using `CrossValidator` can be very expensive. +However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +val cvModel = cv.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0), + new LabeledDocument(4L, "b spark who", 1.0), + new LabeledDocument(5L, "g d a y", 0.0), + new LabeledDocument(6L, "spark fly", 1.0), + new LabeledDocument(7L, "was mapreduce", 0.0), + new LabeledDocument(8L, "e spark program", 1.0), + new LabeledDocument(9L, "a e c l", 0.0), + new LabeledDocument(10L, "spark compile", 1.0), + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) + .addGrid(lr.regParam(), new double[]{0.1, 0.01}) + .build(); + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +CrossValidatorModel cvModel = cv.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    + +## Example: model selection via train validation split +In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in + case of `CrossValidator`. It is therefore less expensive, + but will not produce as reliable results when the training dataset is not sufficiently large. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. +It begins by splitting the dataset into two parts using `trainRatio` parameter +which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), +`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. +Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. +For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. +The `ParamMap` which produces the best evaluation metric is selected as the best option. +`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} + +// Prepare training and test data. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() + +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.DataFrame; + +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); + +{% endhighlight %} +
    + +
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 43772adcf26e1..3bc2b780601c2 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,15 +66,14 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major -concepts. It also contains sections on using algorithms within the Pipelines API, for example: - -* [Feature extraction, transformation, and selection](ml-features.html) +* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Extracting, transforming and selecting features](ml-features.html) +* [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) -* [Decision trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) +* [Advanced topics](ml-advanced.html) + +Some techniques are not available yet in spark.ml, most notably dimensionality reduction +Users can seemlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies From c6a00f420006cf7d8d5e9391953742a0ffc38252 Mon Sep 17 00:00:00 2001 From: Dominik Dahlem Date: Tue, 8 Dec 2015 18:54:10 -0800 Subject: [PATCH 1084/1168] [SPARK-11343][ML] Documentation of float and double prediction/label columns in RegressionEvaluator felixcheung , mengxr Just added a message to require() Author: Dominik Dahlem Closes #9598 from dahlem/ddahlem_regression_evaluator_double_predictions_message_04112015. --- .../apache/spark/ml/evaluation/RegressionEvaluator.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index daaa174a086e0..b6b25ecd01b3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema + val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType - require(predictionType == FloatType || predictionType == DoubleType) + require(predictionType == FloatType || predictionType == DoubleType, + s"Prediction column $predictionColName must be of type float or double, " + + s" but not $predictionType") + val labelColName = $(labelCol) val labelType = schema($(labelCol)).dataType - require(labelType == FloatType || labelType == DoubleType) + require(labelType == FloatType || labelType == DoubleType, + s"Label column $labelColName must be of type float or double, but not $labelType") val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) From 1258f458ec80e54ddb3cb31f0b08334c646e5436 Mon Sep 17 00:00:00 2001 From: Fei Wang Date: Tue, 8 Dec 2015 21:32:31 -0800 Subject: [PATCH 1085/1168] [SPARK-12222] [CORE] Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception Jira: https://issues.apache.org/jira/browse/SPARK-12222 Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception: ``` com.esotericsoftware.kryo.KryoException: Buffer underflow. at com.esotericsoftware.kryo.io.Input.require(Input.java:156) at com.esotericsoftware.kryo.io.Input.skip(Input.java:131) at com.esotericsoftware.kryo.io.Input.skip(Input.java:264) ``` This is caused by a bug of kryo's `Input.skip(long count)`(https://github.com/EsotericSoftware/kryo/issues/119) and we call this method in `KryoInputDataInputBridge`. Instead of upgrade kryo's version, this pr bypass the kryo's `Input.skip(long count)` by directly call another `skip` method in kryo's Input.java(https://github.com/EsotericSoftware/kryo/blob/kryo-2.21/src/com/esotericsoftware/kryo/io/Input.java#L124), i.e. write the bug-fixed version of `Input.skip(long count)` in KryoInputDataInputBridge's `skipBytes` method. more detail link to https://github.com/apache/spark/pull/9748#issuecomment-162860246 Author: Fei Wang Closes #10213 from scwf/patch-1. --- .../spark/serializer/KryoSerializer.scala | 10 ++++++- .../serializer/KryoSerializerSuite.scala | 28 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 62d445f3d7bd9..cb2ac5ea167ec 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -400,7 +400,15 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readUTF(): String = input.readString() // readString in kryo does utf8 override def readInt(): Int = input.readInt() override def readUnsignedShort(): Int = input.readShortUnsigned() - override def skipBytes(n: Int): Int = input.skip(n.toLong).toInt + override def skipBytes(n: Int): Int = { + var remaining: Long = n + while (remaining > 0) { + val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] + input.skip(skip) + remaining -= skip + } + n + } override def readFully(b: Array[Byte]): Unit = input.read(b) override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) override def readLine(): String = throw new UnsupportedOperationException("readLine") diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index f81fe3113106f..9fcc22b608c65 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} + +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -350,6 +354,28 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { + val dir = Utils.createTempDir() + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + input.close() + assert(ret == bitmap) + Utils.deleteRecursively(dir) + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) From 7298d578d8726b046919e11cb94c98e39a781b13 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 9 Dec 2015 15:15:30 +0800 Subject: [PATCH 1086/1168] [SPARK-11676][SQL] Parquet filter tests all pass if filters are not really pushed down Currently Parquet predicate tests all pass even if filters are not pushed down or this is disabled. In this PR, For checking evaluating filters, Simply it makes the expression from `expression.Filter` and then try to create filters just like Spark does. For checking the results, this manually accesses to the child rdd (of `expression.Filter`) and produces the results which should be filtered properly, and then compares it to expected values. Now, if filters are not pushed down or this is disabled, this throws exceptions. Author: hyukjinkwon Closes #9659 from HyukjinKwon/SPARK-11676. --- .../parquet/ParquetFilterSuite.scala | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index cc5aae03d5516..daf41bc292cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -50,27 +50,33 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten.reduceLeftOption(_ && _) - assert(maybeAnalyzedPredicate.isDefined) - - val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) - assert(selectedFilters.nonEmpty) - - selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => - // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[ParquetRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _)) => + maybeRelation = Some(relation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) + assert(f.getClass === filterClass) + } } + checker(stripSparkFilter(query), expected) } - checker(query, expected) } } @@ -104,6 +110,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -347,19 +368,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) val df = sqlContext.read.parquet(path).filter("a = 2") - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - assert(childRDD.count == 1) + assert(stripSparkFilter(df).count == 1) } } } From 81a02af4d430ae74a8c7d71a00dee57984d48e4f Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 9 Dec 2015 15:09:40 +0000 Subject: [PATCH 1087/1168] [SPARK-12031][CORE][BUG] Integer overflow when do sampling Author: uncleGen Closes #10023 from uncleGen/1.6-bugfix. --- .../src/main/scala/org/apache/spark/Partitioner.scala | 4 ++-- .../org/apache/spark/util/random/SamplingUtils.scala | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e4df7af81a6d2..ef9a2dab1c106 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -253,7 +253,7 @@ private[spark] object RangePartitioner { */ def sketch[K : ClassTag]( rdd: RDD[K], - sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => @@ -262,7 +262,7 @@ private[spark] object RangePartitioner { iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) }.collect() - val numItems = sketched.map(_._2.toLong).sum + val numItems = sketched.map(_._2).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae62778..f98932a470165 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,17 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } From 49c3d5df2c25b413dc8fcc59ded70e8dd2e6e72f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 9 Dec 2015 23:30:42 +0800 Subject: [PATCH 1088/1168] [SPARK-12012][SQL] Show more comprehensive PhysicalRDD metadata when visualizing SQL query plan This PR adds a `private[sql]` method `metadata` to `SparkPlan`, which can be used to describe detail information about a physical plan during visualization. Specifically, this PR uses this method to provide details of `PhysicalRDD`s translated from a data source relation. For example, a `ParquetRelation` converted from Hive metastore table `default.psrc` is now shown as the following screenshot: ![image](https://cloud.githubusercontent.com/assets/230655/11526657/e10cb7e6-9916-11e5-9afa-f108932ec890.png) And here is the screenshot for a regular `ParquetRelation` (not converted from Hive metastore table) loaded from a really long path: ![output](https://cloud.githubusercontent.com/assets/230655/11680582/37c66460-9e94-11e5-8f50-842db5309d5a.png) Author: Cheng Lian Closes #10004 from liancheng/spark-12012.physical-rdd-metadata. --- python/pyspark/sql/dataframe.py | 2 +- .../spark/sql/execution/ExistingRDD.scala | 19 +++++--- .../spark/sql/execution/SparkPlan.scala | 5 +++ .../spark/sql/execution/SparkPlanInfo.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../datasources/DataSourceStrategy.scala | 22 +++++++-- .../datasources/parquet/ParquetRelation.scala | 10 +++++ .../sql/execution/ui/SparkPlanGraph.scala | 45 +++++++++++-------- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++- 11 files changed, 87 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 746bb55e14f22..78ab475eb466b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -213,7 +213,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan PhysicalRDD[age#0,name#1] + Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 623348f6768a4..b8a43025882e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -97,22 +97,31 @@ private[sql] case class LogicalRDD( private[sql] case class PhysicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], - extraInformation: String, + override val nodeName: String, + override val metadata: Map[String, String] = Map.empty, override val outputsUnsafeRows: Boolean = false) extends LeafNode { protected override def doExecute(): RDD[InternalRow] = rdd - override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" + } } private[sql] object PhysicalRDD { + // Metadata keys + val INPUT_PATHS = "InputPaths" + val PUSHED_FILTERS = "PushedFilters" + def createFromDataSource( output: Seq[Attribute], rdd: RDD[InternalRow], relation: BaseRelation, - extraInformation: String = ""): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString + extraInformation, - relation.isInstanceOf[HadoopFsRelation]) + metadata: Map[String, String] = Map.empty): PhysicalRDD = { + // All HadoopFsRelations output UnsafeRows + val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a78177751c9dc..ec98f81041343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,6 +67,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } + /** + * Return all metadata that describes more details of this SparkPlan. + */ + private[sql] def metadata: Map[String, String] = Map.empty + /** * Return all metrics containing metrics of this SparkPlan. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 486ce34064e43..4f750ad13ab84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -30,6 +30,7 @@ class SparkPlanInfo( val nodeName: String, val simpleString: String, val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], val metrics: Seq[SQLMetricInfo]) private[sql] object SparkPlanInfo { @@ -41,6 +42,6 @@ private[sql] object SparkPlanInfo { } val children = plan.children.map(fromSparkPlan) - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f67c951bc0663..25e98c0bdd431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => apply(child) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 544d5eccec037..8a15a51d825ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -315,7 +318,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) - val pushedFiltersString = pushedFilters.mkString(" PushedFilter: [", ",", "] ") + val metadata: Map[String, String] = { + val pairs = ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) + } + + relation.relation match { + case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.paths.mkString(", ") + case _ => + } + + pairs.toMap + } if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && @@ -334,7 +350,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. @@ -344,7 +360,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) execution.Project( projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bb3e2786978c5..1af2a394f399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -146,6 +146,12 @@ private[sql] class ParquetRelation( meta } + override def toString: String = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map { tableName => + s"${getClass.getSimpleName}: $tableName" + }.getOrElse(super.toString) + } + override def equals(other: Any): Boolean = other match { case that: ParquetRelation => val schemaEquality = if (shouldMergeSchemas) { @@ -521,6 +527,10 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + // If a ParquetRelation is converted from a Hive metastore table, this option is set to the + // original Hive table name. + private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" + /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. Otherwise, we will create diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6d..3a6eff9399825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -66,7 +66,9 @@ private[sql] object SparkPlanGraph { SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + nodes += node val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) @@ -85,26 +87,33 @@ private[sql] object SparkPlanGraph { * @param metrics metrics that this SparkPlan node will track */ private[ui] case class SparkPlanGraphNode( - id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + id: Long, + name: String, + desc: String, + metadata: Map[String, String], + metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { - val values = { - for (metric <- metrics; - value <- metricsValue.get(metric.accumulatorId)) yield { - metric.name + ": " + value - } + val builder = new mutable.StringBuilder(name) + + val values = for { + metric <- metrics + value <- metricsValue.get(metric.accumulatorId) + } yield { + metric.name + ": " + value } - val label = if (values.isEmpty) { - name - } else { - // If there are metrics, display all metrics in a separate line. We should use an escaped - // "\n" here to follow the dot syntax. - // - // Note: whitespace between two "\n"s is to create an empty line between the name of - // SparkPlan and metrics. If removing it, it won't display the empty line in UI. - name + "\\n \\n" + values.mkString("\\n") - } - s""" $id [label="$label"];""" + + if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + builder ++= "\\n \\n" + builder ++= values.mkString("\\n") + } + + s""" $id [label="${builder.toString()}"];""" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 9ace25dc7d21b..fc8ce6901dfca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -422,7 +422,7 @@ abstract class HadoopFsRelation private[sql]( parameters: Map[String, String]) extends BaseRelation with FileRelation with Logging { - override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") + override def toString: String = getClass.getSimpleName def this() = this(None, Map.empty[String, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a4626259b2823..2fb439f50117a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -169,7 +169,7 @@ class PlannerSuite extends SharedSQLContext { withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan - assert(exp.toString.contains("PushedFilter: [EqualTo(key,15)]")) + assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9a981d02ad67c..08b291e088238 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -411,7 +411,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + metastoreRelation.tableName, + Some(metastoreRelation.databaseName) + ).unquotedString + ) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) From ab069c0bc5d94d75eb31657dacafeb23c1b55e43 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 9 Dec 2015 16:45:13 +0000 Subject: [PATCH 1089/1168] [SPARK-10299][ML] word2vec should allow users to specify the window size Currently word2vec has the window hard coded at 5, some users may want different sizes (for example if using on n-gram input or similar). User request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size . Author: Holden Karau Author: Holden Karau Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size. --- .../apache/spark/ml/feature/Word2Vec.scala | 15 +++++++ .../apache/spark/mllib/feature/Word2Vec.scala | 11 ++++- .../spark/ml/feature/Word2VecSuite.scala | 43 +++++++++++++++++-- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index a8d61b6dea00b..f105a983a34f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) default 5. + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. * Default: 1 @@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group expertSetParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 23b1514e3080e..1f400e1430eba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). @@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging { private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 private var trainWordsCount = 0 private var vocabSize = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a773244cd735e..d561bbbb25529 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - val sqlContext = new SQLContext(sc) + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul expectedSimilarity.zip(similarity).map { case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) } + } + + test("window size") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + val (synonyms, similarity) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } test("Word2Vec read/write") { From 2f2b101683ebc3e17ae3984e6c8c79b562f5a7ec Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 9 Dec 2015 09:50:43 -0800 Subject: [PATCH 1090/1168] [SPARK-10582][YARN][CORE] Fix AM failure situation for dynamic allocation Because of AM failure, the target executor number between driver and AM will be different, which will lead to unexpected behavior in dynamic allocation. So when AM is re-registered with driver, state in `ExecutorAllocationManager` and `CoarseGrainedSchedulerBacked` should be reset. This issue is originally addressed in #8737 , here re-opened again. Thanks a lot KaiXinXiaoLei for finding this issue. andrewor14 and vanzin would you please help to review this, thanks a lot. Author: jerryshao Closes #9963 from jerryshao/SPARK-10582. --- .../spark/ExecutorAllocationManager.scala | 18 +++- .../CoarseGrainedSchedulerBackend.scala | 19 +++++ .../ExecutorAllocationManagerSuite.scala | 84 +++++++++++++++++++ .../cluster/YarnSchedulerBackend.scala | 23 +++++ 4 files changed, 142 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 34c32ce312964..6176e258989db 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -89,6 +89,8 @@ private[spark] class ExecutorAllocationManager( private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) + private val initialNumExecutors = conf.getInt("spark.dynamicAllocation.initialExecutors", + minNumExecutors) // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -121,8 +123,7 @@ private[spark] class ExecutorAllocationManager( // The desired number of executors at this moment in time. If all our executors were to die, this // is the number of executors we would immediately want from the cluster manager. - private var numExecutorsTarget = - conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + private var numExecutorsTarget = initialNumExecutors // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -240,6 +241,19 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } + /** + * Reset the allocation manager to the initial state. Currently this will only be called in + * yarn-client mode when AM re-registers after a failure. + */ + def reset(): Unit = synchronized { + initializing = true + numExecutorsTarget = initialNumExecutors + numExecutorsToAdd = 1 + + executorsPendingToRemove.clear() + removeTimes.clear() + } + /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 505c161141c88..7efe16749e59d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -341,6 +341,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only + * be called in the yarn-client mode when AM re-registers after a failure, also dynamic + * allocation is enabled. + * */ + protected def reset(): Unit = synchronized { + if (Utils.isDynamicAllocationEnabled(conf)) { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + + // Remove all the lingering executors that should be removed but not yet. The reason might be + // because (1) disconnected event is not yet received; (2) executors die silently. + executorDataMap.toMap.foreach { case (eid, _) => + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + } + } + } + override def reviveOffers() { driverEndpoint.send(ReviveOffers) } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 116f027a0f987..fedfbd547b91b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -805,6 +805,90 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) } + test("reset the state of allocation manager") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + + // Allocation manager is reset when adding executor requests are sent without reporting back + // executor added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 4) + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 5) + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + + // Allocation manager is reset when executors are added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + + // Allocation manager is reset when executors are pending to remove + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + removeExecutor(manager, "first") + removeExecutor(manager, "second") + assert(executorsPendingToRemove(manager) === Set("first", "second")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorsPendingToRemove(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index e3dd87798f018..1431bceb256a7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -60,6 +60,9 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + // Flag to specify whether this schedulerBackend should be reset. + private var shouldResetOnAmRegister = false + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -155,6 +158,16 @@ private[spark] abstract class YarnSchedulerBackend( new YarnDriverEndpoint(rpcEnv, properties) } + /** + * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed + * and re-registered itself to driver after a failure. The stale state in driver should be + * cleaned. + */ + override protected def reset(): Unit = { + super.reset() + sc.executorAllocationManager.foreach(_.reset()) + } + /** * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. * This endpoint communicates with the executors and queries the AM for an executor's exit @@ -218,6 +231,8 @@ private[spark] abstract class YarnSchedulerBackend( case None => logWarning("Attempted to check for an executor loss reason" + " before the AM has registered!") + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } } @@ -225,6 +240,13 @@ private[spark] abstract class YarnSchedulerBackend( case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") amEndpoint = Option(am) + if (!shouldResetOnAmRegister) { + shouldResetOnAmRegister = true + } else { + // AM is already registered before, this potentially means that AM failed and + // a new one registered after the failure. This will only happen in yarn-client mode. + reset() + } case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) @@ -270,6 +292,7 @@ private[spark] abstract class YarnSchedulerBackend( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (amEndpoint.exists(_.address == remoteAddress)) { logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + amEndpoint = None } } From d525dd8cbbfcdf849232cea88b5cf4ac531ef11c Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 9 Dec 2015 10:25:38 -0800 Subject: [PATCH 1091/1168] [SPARK-12241][YARN] Improve failure reporting in Yarn client obtainTokenForHBase() This lines up the HBase token logic with that done for Hive in SPARK-11265: reflection with only CFNE being swallowed. There is a test, one which doesn't try to put HBase on the yarn/test class and really do the reflection (the way the hive introspection does). If people do want that then it could be added with careful POM work +also: cut an incorrect comment from the Hive test case before copying it, and a couple of imports that may have been related to the hive test in the past. Author: Steve Loughran Closes #10227 from steveloughran/stevel/patches/SPARK-12241-obtainTokenForHBase. --- .../org/apache/spark/deploy/yarn/Client.scala | 32 ++---------- .../deploy/yarn/YarnSparkHadoopUtil.scala | 51 ++++++++++++++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 12 ++++- 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0590d2d222ec..7742ec92eb4e8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1369,40 +1369,16 @@ object Client extends Logging { } /** - * Obtain security token for HBase. + * Obtain a security token for HBase. */ def obtainTokenForHBase( sparkConf: SparkConf, conf: Configuration, credentials: Credentials): Unit = { if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val confCreate = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). - getMethod("create", classOf[Configuration]) - val obtainToken = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). - getMethod("obtainToken", classOf[Configuration]) - - logDebug("Attempting to fetch HBase security token.") - - val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] - if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - logInfo("Added HBase security token to credentials.") - } - } catch { - case e: java.lang.NoSuchMethodException => - logInfo("HBase Method not found: " + e) - case e: java.lang.ClassNotFoundException => - logDebug("HBase Class not found: " + e) - case e: java.lang.NoClassDefFoundError => - logDebug("HBase Class not found: " + e) - case e: Exception => - logError("Exception when obtaining HBase security token: " + e) + YarnSparkHadoopUtil.get.obtainTokenForHBase(conf).foreach { token => + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index a290ebeec9001..36a2d61429887 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -216,6 +216,55 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { None } } + + /** + * Obtain a security token for HBase. + * + * Requirements + * + * 1. `"hbase.security.authentication" == "kerberos"` + * 2. The HBase classes `HBaseConfiguration` and `TokenUtil` could be loaded + * and invoked. + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if the requirements were met, `None` if not. + */ + def obtainTokenForHBase(conf: Configuration): Option[Token[TokenIdentifier]] = { + try { + obtainTokenForHBaseInner(conf) + } catch { + case e: ClassNotFoundException => + logInfo(s"HBase class not found $e") + logDebug("HBase class not found", e) + None + } + } + + /** + * Obtain a security token for HBase if `"hbase.security.authentication" == "kerberos"` + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if one was needed + */ + def obtainTokenForHBaseInner(conf: Configuration): Option[Token[TokenIdentifier]] = { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + logDebug("Attempting to fetch HBase security token.") + Some(obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]]) + } else { + None + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index a70e66d39a64e..3fafc91a166aa 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -27,7 +27,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -259,7 +258,6 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") }) - // expect exception trapping code to unwind this hive-side exception assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) }) @@ -276,6 +274,16 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging inner } + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + val util = new YarnSparkHadoopUtil + intercept[ClassNotFoundException] { + util.obtainTokenForHBaseInner(hadoopConf) + } + util.obtainTokenForHBase(hadoopConf) should be (None) + } + // This test needs to live here because it depends on isYarnMode returning true, which can only // happen in the YARN module. test("security manager token generation") { From a012ce5978f5f89ef192cacaebf714ec6a768a61 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Dec 2015 11:39:59 -0800 Subject: [PATCH 1092/1168] [SPARK-12165][SPARK-12189] Fix bugs in eviction of storage memory by execution This patch fixes a bug in the eviction of storage memory by execution. ## The bug: In general, execution should be able to evict storage memory when the total storage memory usage is greater than `maxMemory * spark.memory.storageFraction`. Due to a bug, however, Spark might wind up evicting no storage memory in certain cases where the storage memory usage was between `maxMemory * spark.memory.storageFraction` and `maxMemory`. For example, here is a regression test which illustrates the bug: ```scala val maxMemory = 1000L val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) // Since we used the default storage fraction (0.5), we should be able to allocate 500 bytes // of storage memory which are immune to eviction by execution memory pressure. // Acquire enough storage memory to exceed the storage region size assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // At this point, storage is using 250 more bytes of memory than it is guaranteed, so execution // should be able to reclaim up to 250 bytes of storage memory. // Therefore, execution should now be able to require up to 500 bytes of memory: assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) // <--- fails by only returning 250L assert(mm.storageMemoryUsed === 500L) assert(mm.executionMemoryUsed === 500L) assertEvictBlocksToFreeSpaceCalled(ms, 250L) ``` The problem relates to the control flow / interaction between `StorageMemoryPool.shrinkPoolToReclaimSpace()` and `MemoryStore.ensureFreeSpace()`. While trying to allocate the 500 bytes of execution memory, the `UnifiedMemoryManager` discovers that it will need to reclaim 250 bytes of memory from storage, so it calls `StorageMemoryPool.shrinkPoolToReclaimSpace(250L)`. This method, in turn, calls `MemoryStore.ensureFreeSpace(250L)`. However, `ensureFreeSpace()` first checks whether the requested space is less than `maxStorageMemory - storageMemoryUsed`, which will be true if there is any free execution memory because it turns out that `MemoryStore.maxStorageMemory = (maxMemory - onHeapExecutionMemoryPool.memoryUsed)` when the `UnifiedMemoryManager` is used. The control flow here is somewhat confusing (it grew to be messy / confusing over time / as a result of the merging / refactoring of several components). In the pre-Spark 1.6 code, `ensureFreeSpace` was called directly by the `MemoryStore` itself, whereas in 1.6 it's involved in a confusing control flow where `MemoryStore` calls `MemoryManager.acquireStorageMemory`, which then calls back into `MemoryStore.ensureFreeSpace`, which, in turn, calls `MemoryManager.freeStorageMemory`. ## The solution: The solution implemented in this patch is to remove the confusing circular control flow between `MemoryManager` and `MemoryStore`, making the storage memory acquisition process much more linear / straightforward. The key changes: - Remove a layer of inheritance which made the memory manager code harder to understand (53841174760a24a0df3eb1562af1f33dbe340eb9). - Move some bounds checks earlier in the call chain (13ba7ada77f87ef1ec362aec35c89a924e6987cb). - Refactor `ensureFreeSpace()` so that the part which evicts blocks can be called independently from the part which checks whether there is enough free space to avoid eviction (7c68ca09cb1b12f157400866983f753ac863380e). - Realize that this lets us remove a layer of overloads from `ensureFreeSpace` (eec4f6c87423d5e482b710e098486b3bbc4daf06). - Realize that `ensureFreeSpace()` can simply be replaced with an `evictBlocksToFreeSpace()` method which is called [after we've already figured out](https://github.com/apache/spark/blob/2dc842aea82c8895125d46a00aa43dfb0d121de9/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala#L88) how much memory needs to be reclaimed via eviction; (2dc842aea82c8895125d46a00aa43dfb0d121de9). Along the way, I fixed some problems with the mocks in `MemoryManagerSuite`: the old mocks would [unconditionally](https://github.com/apache/spark/blob/80a824d36eec9d9a9f092ee1741453851218ec73/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala#L84) report that a block had been evicted even if there was enough space in the storage pool such that eviction would be avoided. I also fixed a problem where `StorageMemoryPool._memoryUsed` might become negative due to freed memory being double-counted when excution evicts storage. The problem was that `StorageMemoryPoolshrinkPoolToFreeSpace` would [decrement `_memoryUsed`](https://github.com/apache/spark/commit/7c68ca09cb1b12f157400866983f753ac863380e#diff-935c68a9803be144ed7bafdd2f756a0fL133) even though `StorageMemoryPool.freeMemory` had already decremented it as each evicted block was freed. See SPARK-12189 for details. Author: Josh Rosen Author: Andrew Or Closes #10170 from JoshRosen/SPARK-12165. --- .../apache/spark/memory/MemoryManager.scala | 11 +- .../spark/memory/StaticMemoryManager.scala | 37 ++++- .../spark/memory/StorageMemoryPool.scala | 37 +++-- .../spark/memory/UnifiedMemoryManager.scala | 8 +- .../apache/spark/storage/MemoryStore.scala | 76 ++-------- .../spark/memory/MemoryManagerSuite.scala | 137 +++++++++--------- .../memory/StaticMemoryManagerSuite.scala | 52 ++++--- .../memory/UnifiedMemoryManagerSuite.scala | 76 +++++++--- 8 files changed, 230 insertions(+), 204 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ceb8ea434e1be..ae9e1ac0e246b 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -77,9 +77,7 @@ private[spark] abstract class MemoryManager( def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) - } + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -109,12 +107,7 @@ private[spark] abstract class MemoryManager( def acquireExecutionMemory( numBytes: Long, taskAttemptId: Long, - memoryMode: MemoryMode): Long = synchronized { - memoryMode match { - case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - } - } + memoryMode: MemoryMode): Long /** * Release numBytes of execution memory belonging to the given task. diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 12a094306861f..3554b558f2123 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -49,19 +49,50 @@ private[spark] class StaticMemoryManager( } // Max number of bytes worth of blocks to evict when unrolling - private val maxMemoryToEvictForUnroll: Long = { + private val maxUnrollMemory: Long = { (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + false + } else { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } + } + override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory - val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) - val numBytesToFree = math.min(numBytes, maxNumBytesToFree) + val freeMemory = storageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. + val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } + + private[memory] + override def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index fc4f0357e9f16..70af83b5ee092 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -65,7 +65,8 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { - acquireMemory(blockId, numBytes, numBytes, evictedBlocks) + val numBytesToFree = math.max(0, numBytes - memoryFree) + acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } /** @@ -73,7 +74,7 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w * * @param blockId the ID of the block we are acquiring storage memory for * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks + * @param numBytesToFree the amount of space to be freed through evicting blocks * @return whether all N bytes were successfully granted. */ def acquireMemory( @@ -84,16 +85,18 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } } // NOTE: If the memory store evicts blocks, then those evictions will synchronously call - // back into this StorageMemoryPool in order to free. Therefore, these variables should have - // been updated. + // back into this StorageMemoryPool in order to free memory. Therefore, these variables + // should have been updated. val enoughMemory = numBytesToAcquire <= memoryFree if (enoughMemory) { _memoryUsed += numBytesToAcquire @@ -121,18 +124,20 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w */ def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { // First, shrink the pool by reclaiming free memory: - val spaceFreedByReleasingUnusedMemory = Math.min(spaceToFree, memoryFree) + val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) decrementPoolSize(spaceFreedByReleasingUnusedMemory) - if (spaceFreedByReleasingUnusedMemory == spaceToFree) { - spaceFreedByReleasingUnusedMemory - } else { + val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory + if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - memoryStore.ensureFreeSpace(spaceToFree - spaceFreedByReleasingUnusedMemory, evictedBlocks) + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, evictedBlocks) val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum - _memoryUsed -= spaceFreedByEviction + // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do + // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. decrementPoolSize(spaceFreedByEviction) spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } else { + spaceFreedByReleasingUnusedMemory } } } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0f1ea9ab39c07..0b9f6a9dc0525 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -100,7 +100,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. - super.acquireExecutionMemory(numBytes, taskAttemptId, memoryMode) + offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } } @@ -110,6 +110,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + return false + } if (numBytes > storageMemoryPool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4dbac388e098b..bdab8c2332fae 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -406,85 +406,41 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } /** - * Try to free up a given amount of space by evicting existing blocks. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(None, space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a block by evicting existing ones. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - blockId: BlockId, - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(Some(blockId), space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a particular block, but can fail if - * either the block is bigger than our memory or it would require replacing another block - * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private def ensureFreeSpace( + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def evictBlocksToFreeSpace( blockId: Option[BlockId], space: Long, droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + assert(space > 0) memoryManager.synchronized { - val freeMemory = maxMemory - memoryUsed + var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L - - logInfo(s"Ensuring $space bytes of free space " + - blockId.map { id => s"for block $id" }.getOrElse("") + - s"(free: $freeMemory, max: $maxMemory)") - - // Fail fast if the block simply won't fit - if (space > maxMemory) { - logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + - s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") - return false - } - - // No need to evict anything if there is already enough free space - if (freeMemory >= space) { - return true - } - // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. entries.synchronized { val iterator = entries.entrySet().iterator() - while (freeMemory + selectedMemory < space && iterator.hasNext) { + while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { selectedBlocks += blockId - selectedMemory += pair.getValue.size + freedMemory += pair.getValue.size } } } - if (freeMemory + selectedMemory >= space) { + if (freedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index f55d435fa33a6..555b640cb4244 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -24,9 +24,10 @@ import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} import org.mockito.Matchers.{any, anyLong} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite @@ -36,105 +37,105 @@ import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel /** * Helper trait for sharing code among [[MemoryManager]] tests. */ -private[memory] trait MemoryManagerSuite extends SparkFunSuite { +private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAfterEach { - import MemoryManagerSuite.DEFAULT_ENSURE_FREE_SPACE_CALLED + protected val evictedBlocks = new mutable.ArrayBuffer[(BlockId, BlockStatus)] + + import MemoryManagerSuite.DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED // Note: Mockito's verify mechanism does not provide a way to reset method call counts // without also resetting stubbed methods. Since our test code relies on the latter, - // we need to use our own variable to track invocations of `ensureFreeSpace`. + // we need to use our own variable to track invocations of `evictBlocksToFreeSpace`. /** - * The amount of free space requested in the last call to [[MemoryStore.ensureFreeSpace]] + * The amount of space requested in the last call to [[MemoryStore.evictBlocksToFreeSpace]]. * - * This set whenever [[MemoryStore.ensureFreeSpace]] is called, and cleared when the test - * code makes explicit assertions on this variable through [[assertEnsureFreeSpaceCalled]]. + * This set whenever [[MemoryStore.evictBlocksToFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through + * [[assertEvictBlocksToFreeSpaceCalled]]. */ - private val ensureFreeSpaceCalled = new AtomicLong(DEFAULT_ENSURE_FREE_SPACE_CALLED) + private val evictBlocksToFreeSpaceCalled = new AtomicLong(0) + + override def beforeEach(): Unit = { + super.beforeEach() + evictedBlocks.clear() + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } /** - * Make a mocked [[MemoryStore]] whose [[MemoryStore.ensureFreeSpace]] method is stubbed. + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is stubbed. * - * This allows our test code to release storage memory when [[MemoryStore.ensureFreeSpace]] - * is called without relying on [[org.apache.spark.storage.BlockManager]] and all of its - * dependencies. + * This allows our test code to release storage memory when these methods are called + * without relying on [[org.apache.spark.storage.BlockManager]] and all of its dependencies. */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { - val ms = mock(classOf[MemoryStore]) - when(ms.ensureFreeSpace(anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 0)) - when(ms.ensureFreeSpace(any(), anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 1)) + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } /** - * Make an [[Answer]] that stubs [[MemoryStore.ensureFreeSpace]] with the right arguments. - */ - private def ensureFreeSpaceAnswer(mm: MemoryManager, numBytesPos: Int): Answer[Boolean] = { + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. + */ + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Boolean] = { new Answer[Boolean] { override def answer(invocation: InvocationOnMock): Boolean = { val args = invocation.getArguments - require(args.size > numBytesPos, s"bad test: expected >$numBytesPos arguments " + - s"in ensureFreeSpace, found ${args.size}") - require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + - s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") - val numBytes = args(numBytesPos).asInstanceOf[Long] - val success = mockEnsureFreeSpace(mm, numBytes) - if (success) { + val numBytesToFree = args(1).asInstanceOf[Long] + assert(numBytesToFree > 0) + require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "bad test: evictBlocksToFreeSpace() variable was not reset") + evictBlocksToFreeSpaceCalled.set(numBytesToFree) + if (numBytesToFree <= mm.storageMemoryUsed) { + // We can evict enough blocks to fulfill the request for space + mm.releaseStorageMemory(numBytesToFree) args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytes, 0L, 0L))) + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + // We need to add this call so that that the suite-level `evictedBlocks` is updated when + // execution evicts storage; in that case, args.last will not be equal to evictedBlocks + // because it will be a temporary buffer created inside of the MemoryManager rather than + // being passed in by the test code. + if (!(evictedBlocks eq args.last)) { + evictedBlocks.append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + } + true + } else { + // No blocks were evicted because eviction would not free enough space. + false } - success - } - } - } - - /** - * Simulate the part of [[MemoryStore.ensureFreeSpace]] that releases storage memory. - * - * This is a significant simplification of the real method, which actually drops existing - * blocks based on the size of each block. Instead, here we simply release as many bytes - * as needed to ensure the requested amount of free space. This allows us to set up the - * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in - * many other dependencies. - * - * Every call to this method will set a global variable, [[ensureFreeSpaceCalled]], that - * records the number of bytes this is called with. This variable is expected to be cleared - * by the test code later through [[assertEnsureFreeSpaceCalled]]. - */ - private def mockEnsureFreeSpace(mm: MemoryManager, numBytes: Long): Boolean = mm.synchronized { - require(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "bad test: ensure free space variable was not reset") - // Record the number of bytes we freed this call - ensureFreeSpaceCalled.set(numBytes) - if (numBytes <= mm.maxStorageMemory) { - def freeMemory = mm.maxStorageMemory - mm.storageMemoryUsed - val spaceToRelease = numBytes - freeMemory - if (spaceToRelease > 0) { - mm.releaseStorageMemory(spaceToRelease) } - freeMemory >= numBytes - } else { - // We attempted to free more bytes than our max allowable memory - false } } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters. */ - protected def assertEnsureFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { - assert(ensureFreeSpaceCalled.get() === numBytes, - s"expected ensure free space to be called with $numBytes") - ensureFreeSpaceCalled.set(DEFAULT_ENSURE_FREE_SPACE_CALLED) + protected def assertEvictBlocksToFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === numBytes, + s"expected evictBlocksToFreeSpace() to be called with $numBytes") + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is NOT called. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is NOT called. */ - protected def assertEnsureFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { - assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "ensure free space should not have been called!") + protected def assertEvictBlocksToFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "evictBlocksToFreeSpace() should not have been called!") + assert(evictedBlocks.isEmpty) } /** @@ -291,5 +292,5 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { } private object MemoryManagerSuite { - private val DEFAULT_ENSURE_FREE_SPACE_CALLED = -1L + private val DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED = -1L } diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 54cb28c389c2f..6700b94f0f57f 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito.when import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] /** * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. @@ -85,33 +82,38 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxStorageMem + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -133,7 +135,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 50L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released @@ -151,21 +153,25 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("lonely water") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 100L) mm.releaseUnrollMemory(40L) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireUnrollMemory(dummyBlock, 500L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. - assertEnsureFreeSpaceCalled(ms, 340L) - assert(mm.storageMemoryUsed === 560L) - when(ms.currentUnrollMemory).thenReturn(560L) - assert(!mm.acquireUnrollMemory(dummyBlock, 800L, evictedBlocks)) - assert(mm.storageMemoryUsed === 560L) - // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed - assertEnsureFreeSpaceCalled(ms, 0L) + // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. + assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 + assert(mm.storageMemoryUsed === 1000L) + evictedBlocks.clear() + assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 + assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e97c898a44783..71221deeb4c28 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { private val dummyBlock = TestBlockId("--") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] private val storageFraction: Double = 0.5 @@ -78,33 +75,40 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxMemory + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -117,25 +121,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) // Acquire enough storage memory to exceed the storage region assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // Execution needs to request 250 bytes to evict storage memory assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 300L) - assertEnsureFreeSpaceCalled(ms, 50L) - assert(mm.executionMemoryUsed === 300L) + assert(mm.storageMemoryUsed === 700L) + assertEvictBlocksToFreeSpaceCalled(ms, 50L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() mm.releaseAllStorageMemory() require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 400L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 400L) assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, @@ -143,7 +149,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("execution memory requests smaller than free memory should evict storage (SPARK-12165)") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough storage memory to exceed the storage region size + assert(mm.acquireStorageMemory(dummyBlock, 700L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 700L) + // SPARK-12165: previously, MemoryStore would not evict anything because it would + // mistakenly think that the 300 bytes of free space was still available even after + // using it to expand the execution pool. Consequently, no storage memory was released + // and the following call granted only 300 bytes to execution. + assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) + assertEvictBlocksToFreeSpaceCalled(ms, 200L) + assert(mm.storageMemoryUsed === 500L) + assert(mm.executionMemoryUsed === 500L) + assert(evictedBlocks.nonEmpty) } test("storage does not evict execution") { @@ -154,32 +180,34 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 250L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should still not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 850L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) } test("small heap") { From 874635ec86a33efe5187c436c08f8c338275f323 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 9 Dec 2015 19:47:38 +0000 Subject: [PATCH 1093/1168] [SPARK-11824][WEBUI] WebUI does not render descriptions with 'bad' HTML, throws console error Don't warn when description isn't valid HTML since it may properly be like "SELECT ... where foo <= 1" The tests for this code indicate that it's normal to handle strings like this that don't contain HTML as a string rather than markup. Hence logging every such instance as a warning is too noisy since it's not a problem. this is an issue for stages whose name contain SQL like the above CC tdas as author of this bit of code Author: Sean Owen Closes #10159 from srowen/SPARK-11824. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1e8194f57888e..81a6f07ec836a 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -448,7 +448,6 @@ private[spark] object UIUtils extends Logging { new RuleTransformer(rule).transform(xml) } catch { case NonFatal(e) => - logWarning(s"Invalid job description: $desc ", e) {desc} } } From 6ae8e1e7e54b955111d0a452a04173f30f91a6e6 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 9 Dec 2015 12:00:48 -0800 Subject: [PATCH 1094/1168] [SPARK-11551][DOC] Replace example code in ml-features.md using include_example PR on behalf of somideshmukh, thanks! Author: Xusen Yin Author: somideshmukh Closes #10219 from yinxusen/SPARK-11551. --- docs/ml-features.md | 1112 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 71 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 51 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 54 + .../examples/ml/JavaOneHotEncoderExample.java | 78 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 54 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 61 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 43 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 43 + .../main/python/ml/onehot_encoder_example.py | 48 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 43 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 40 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 52 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 52 + .../examples/ml/MinMaxScalerExample.scala | 50 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 52 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 53 + .../ml/PolynomialExpansionExample.scala | 51 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 52 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 48 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 54 + .../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2820 insertions(+), 1061 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 55e401221917e..7ad7c4eb7ea65 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
    @@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    @@ -1030,30 +565,7 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -1061,46 +573,7 @@ encoded.select("id", "categoryVec").show() Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -encoded.select("id", "categoryVec").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1108,25 +581,7 @@ encoded.select("id", "categoryVec").show(); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1150,23 +605,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1174,30 +613,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1205,17 +621,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1232,22 +638,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1255,24 +646,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1280,19 +654,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1316,23 +678,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1340,25 +686,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1366,20 +694,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1409,21 +724,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1432,24 +733,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1473,23 +757,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1497,38 +765,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1536,19 +773,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1580,25 +805,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1606,41 +813,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1648,19 +821,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
    - ## SQLTransformer @@ -1763,19 +925,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1783,36 +933,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1820,19 +941,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
    @@ -1962,33 +1071,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1996,41 +1079,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -2067,21 +1116,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -2089,38 +1124,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -2128,21 +1132,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 0000000000000..9698cac504371 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 0000000000000..8ad369cc93e8a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + bucketedData.show(); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 0000000000000..35c0d534a45e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..2898accec61b0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 0000000000000..2d50ba7faa1a1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 0000000000000..8fd75ed8b5f4e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 0000000000000..ed3f6163c0558 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + l1NormData.show(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 0000000000000..bc509607084b1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + encoded.select("id", "categoryVec").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 0000000000000..8282fab084f36 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 0000000000000..668f71e64056b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 0000000000000..1e1062b541ad9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 0000000000000..da4756643f3c4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 0000000000000..b6b201c6b68d2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 0000000000000..05d12c1e702f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 0000000000000..617dc3f66e3bf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 0000000000000..7e230b5897c1e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 0000000000000..545758e31d972 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + indexedData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 0000000000000..4d5cb04ff5e2b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 0000000000000..317cfa638a5a9 --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 0000000000000..4304255f350db --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 0000000000000..c85cb0d89543c --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 0000000000000..f2d85f53e7219 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 0000000000000..d490221474c24 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + lInfNormData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000000..0f94c26638d35 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 0000000000000..a17181f1b8a51 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data, ["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 0000000000000..3d4fafd1a42e9 --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext\ + .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 0000000000000..b544a14700762 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 0000000000000..ae7aa85005bcd --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 0000000000000..01f94af8ca752 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 0000000000000..58a8cb5d56b73 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 0000000000000..ce9b225be5357 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 0000000000000..04f64839f188d --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 0000000000000..146f41c1dd903 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + indexedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 0000000000000..e724aa587294b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 0000000000000..7c75e3d72b47b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. + val bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 0000000000000..314c2c28a2a10 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 0000000000000..872de51dc75df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 0000000000000..fb7f28c9886bb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 0000000000000..8a85f71b56f3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 0000000000000..1990b55e8c5e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + lInfNormData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 0000000000000..66602e2118506 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 0000000000000..4c806f71a32c3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 0000000000000..39fb79af35766 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 0000000000000..286866edea502 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 0000000000000..e0a41e383a7ea --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 0000000000000..655ffce08d3ab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 0000000000000..9fa494cd2473b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 0000000000000..01e0d1388a2f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 0000000000000..d527924419f81 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 0000000000000..685891c164e70 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + indexedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 0000000000000..04f19829eff87 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From c3f99697b4f3b1976080c52398c52f3dc5ac35bc Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 9 Dec 2015 17:16:01 -0800 Subject: [PATCH 1095/1168] [SPARK-12211][DOC][GRAPHX] Fix version number in graphx doc for migration from 1.1 Migration from 1.1 section added to the GraphX doc in 1.2.0 (see https://spark.apache.org/docs/1.2.0/graphx-programming-guide.html#migrating-from-spark-11) uses \{{site.SPARK_VERSION}} as the version where changes were introduced, it should be just 1.2. Author: Andrew Ray Closes #10206 from aray/graphx-doc-1.1-migration. --- docs/graphx-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 6a512ab234bb2..9dea9b5904d2d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -70,7 +70,7 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato ## Migrating from Spark 1.1 -GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: +GraphX in Spark 1.2 contains a few user facing API changes: 1. To improve performance we have introduced a new version of [`mapReduceTriplets`][Graph.mapReduceTriplets] called From 0e97f52a62bb9f773bd089d8f9397cff74b3791d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 9 Dec 2015 17:24:04 -0800 Subject: [PATCH 1096/1168] [SPARK-12165][ADDENDUM] Fix outdated comments on unroll test JoshRosen Author: Andrew Or Closes #10229 from andrewor14/unroll-test-comments. --- .../spark/memory/StaticMemoryManagerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 6700b94f0f57f..272253bc94e91 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -163,15 +163,20 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. + // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. + // Requesting 240 more bytes of unroll memory will leave our total unroll memory at + // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 assert(mm.storageMemoryUsed === 1000L) evictedBlocks.clear() + // We already have 300 bytes of unroll memory, so requesting 150 more will leave us + // above the 400-byte limit. Since there is not enough free memory, this request will + // fail even after evicting as much as we can (400 - 300 = 100 bytes). assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 - assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 900L) // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) From c4187c6dc235364d608768a121a3271db5793ede Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 9 Dec 2015 18:09:36 -0800 Subject: [PATCH 1097/1168] [SPARK-11678][SQL][DOCS] Document basePath in the programming guide. This PR adds document for `basePath`, which is a new parameter used by `HadoopFsRelation`. The compiled doc is shown below. ![image](https://cloud.githubusercontent.com/assets/2072857/11673132/1ba01192-9dcb-11e5-98d9-ac0b4e92e98c.png) JIRA: https://issues.apache.org/jira/browse/SPARK-11678 Author: Yin Huai Closes #10211 from yhuai/basePathDoc. --- docs/sql-programming-guide.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9f87accd30f40..3f9a831eddc88 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1233,6 +1233,13 @@ infer the data types of the partitioning columns. For these use cases, the autom can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. When type inference is disabled, string type will be used for the partitioning columns. +Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +partitioning column. If users need to specify the base path that partition discovery +should start with, they can set `basePath` in the data source options. For example, +when `path/to/table/gender=male` is the path of the data and +users set `basePath` to `path/to/table/`, `gender` will be a partitioning column. ### Schema Merging From c006bba2fee441726029a41db77ed4bdf401ecff Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 9 Dec 2015 18:37:35 -0800 Subject: [PATCH 1098/1168] [SPARK-11796] Fix httpclient and httpcore depedency issues related to docker-client This commit fixes dependency issues which prevented the Docker-based JDBC integration tests from running in the Maven build. Author: Mark Grover Closes #9876 from markgrover/master_docker. --- docker-integration-tests/pom.xml | 22 ++++++++++++++++++++++ pom.xml | 28 ++++++++++++++++++++++++++++ sql/core/pom.xml | 2 -- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index dee0c4aa37ae8..39d3f344615e1 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -71,6 +71,18 @@
    + + org.apache.httpcomponents + httpclient + 4.5 + test + + + org.apache.httpcomponents + httpcore + 4.4.1 + test + com.google.guava @@ -109,6 +121,16 @@ ${project.version} test + + mysql + mysql-connector-java + test + + + org.postgresql + postgresql + test + prob=$prob, prediction=$prediction") - } - -{% endhighlight %} - - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} -
    - - - -## Example: model selection via cross-validation - -An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. -`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. - -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). -`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. -`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. - -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` -method in each of these evaluators. - -The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. -`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -The following example demonstrates using `CrossValidator` to select from a grid of parameters. -To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. - -Note that cross-validation over a grid of parameters is expensive. -E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. -In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). -In other words, using `CrossValidator` can be very expensive. -However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} -
    - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    - -## Example: model selection via train validation split -In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. -`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in - case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large. - -`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, -and an `Evaluator`. -It begins by splitting the dataset into two parts using `trainRatio` parameter -which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), -`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. -Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. -For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `ParamMap` which produces the best evaluation metric is selected as the best option. -`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} - -// Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} -
    - -
    -{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.DataFrame; - -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} -
    - -
    \ No newline at end of file diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 0c13d7d0c82b3..a8754835cab95 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,148 +1,8 @@ --- layout: global -title: Linear Methods - ML -displayTitle: ML - Linear Methods +title: Linear methods - spark.ml +displayTitle: Linear methods - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In MLlib, we implement popular linear methods such as logistic -regression and linear least squares with $L_1$ or $L_2$ regularization. -Refer to [the linear methods in mllib](mllib-linear-methods.html) for -details. In `spark.ml`, we also include Pipelines API for [Elastic -net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid -of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization -and variable selection via the elastic -net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). -Mathematically, it is defined as a convex combination of the $L_1$ and -the $L_2$ regularization terms: -`\[ -\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 -\]` -By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ -regularization as special cases. For example, if a [linear -regression](https://en.wikipedia.org/wiki/Linear_regression) model is -trained with the elastic net parameter $\alpha$ set to $1$, it is -equivalent to a -[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. -On the other hand, if $\alpha$ is set to $0$, the trained model reduces -to a [ridge -regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. -We implement Pipelines API for both linear regression and logistic -regression with elastic net regularization. - -## Example: Logistic Regression - -The following example shows how to train a logistic regression model -with elastic net regularization. `elasticNetParam` corresponds to -$\alpha$ and `regParam` corresponds to $\lambda$. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} -
    - -
    -{% include_example python/ml/logistic_regression_with_elastic_net.py %} -
    - -
    - -The `spark.ml` implementation of logistic regression also supports -extracting a summary of the model over the training set. Note that the -predictions and metrics which are stored as `Dataframe` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence -only available on the driver. - -
    - -
    - -[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) -provides a summary for a -[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} -
    - -
    -[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) -provides a summary for a -[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} -
    - - -
    -Logistic regression model summary is not yet supported in Python. -
    - -
    - -## Example: Linear Regression - -The interface for working with linear regression models and model -summaries is similar to the logistic regression case. The following -example demonstrates training an elastic net regularized linear -regression model and extracting model summary statistics. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} -
    - -
    - -{% include_example python/ml/linear_regression_with_elastic_net.py %} -
    - -
    - -# Optimization - -The optimization algorithm underlying the implementation is called -[Orthant-Wise Limited-memory -QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 -regularization and elastic net. - + > This section has been moved into the + [classification and regression section](ml-classification-regression.html). diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md index ab275213b9a84..856ceb2f4e7f6 100644 --- a/docs/ml-survival-regression.md +++ b/docs/ml-survival-regression.md @@ -1,96 +1,8 @@ --- layout: global -title: Survival Regression - ML -displayTitle: ML - Survival Regression +title: Survival Regression - spark.ml +displayTitle: Survival Regression - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) -model which is a parametric survival regression model for censored data. -It describes a model for the log of survival time, so it's often called -log-linear model for survival analysis. Different from -[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model -designed for the same purpose, the AFT model is more easily to parallelize -because each instance contribute to the objective function independently. - -Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of -subjects i = 1, ..., n, with possible right-censoring, -the likelihood function under the AFT model is given as: -`\[ -L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} -\]` -Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. -Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function -assumes the form: -`\[ -\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] -\]` -Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, -and $f_{0}(\epsilon_{i})$ is corresponding density function. - -The most commonly used AFT model is based on the Weibull distribution of the survival time. -The Weibull distribution for lifetime corresponding to extreme value distribution for -log of the lifetime, and the $S_{0}(\epsilon)$ function is: -`\[ -S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) -\]` -the $f_{0}(\epsilon_{i})$ function is: -`\[ -f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) -\]` -The log-likelihood function for AFT model with Weibull distribution of lifetime is: -`\[ -\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] -\]` -Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, -the loss function we use to optimize is $-\iota(\beta,\sigma)$. -The gradient functions for $\beta$ and $\log\sigma$ respectively are: -`\[ -\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} -\]` -`\[ -\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] -\]` - -The AFT model can be formulated as a convex optimization problem, -i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ -that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. -The optimization algorithm underlying the implementation is L-BFGS. -The implementation matches the result from R's survival function -[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) - -## Example: - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} -
    - -
    -{% include_example python/ml/aft_survival_regression.py %} -
    - -
    \ No newline at end of file + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#survival-regression). diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 0210950b89906..aaf8bd465c9ab 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,10 +1,10 @@ --- layout: global -title: Classification and Regression - MLlib -displayTitle: MLlib - Classification and Regression +title: Classification and Regression - spark.mllib +displayTitle: Classification and Regression - spark.mllib --- -MLlib supports various methods for +The `spark.mllib` package supports various methods for [binary classification](http://en.wikipedia.org/wiki/Binary_classification), [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification), and diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8fbced6c87d9f..48d64cd402b11 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,7 +1,7 @@ --- layout: global -title: Clustering - MLlib -displayTitle: MLlib - Clustering +title: Clustering - spark.mllib +displayTitle: Clustering - spark.mllib --- [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets @@ -10,19 +10,19 @@ often used for exploratory analysis and/or as a component of a hierarchical [supervised learning](https://en.wikipedia.org/wiki/Supervised_learning) pipeline (in which distinct classifiers or regression models are trained for each cluster). -MLlib supports the following models: +The `spark.mllib` package supports the following models: * Table of contents {:toc} ## K-means -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +[K-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a -predefined number of clusters. The MLlib implementation includes a parallelized +predefined number of clusters. The `spark.mllib` implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in `spark.mllib` has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -171,7 +171,7 @@ sameModel = KMeansModel.load(sc, "myModelPath") A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the +each with its own probability. The `spark.mllib` implementation uses the [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) algorithm to induce the maximum-likelihood model given a set of samples. The implementation has the following parameters: @@ -308,13 +308,13 @@ graph given pairwise similarties as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. -MLlib includes an implementation of PIC using GraphX as its backend. +`spark.mllib` includes an implementation of PIC using GraphX as its backend. It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. The similarities must be nonnegative. PIC assumes that the similarity measure is symmetric. A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. If a pair is missing from input, their similarity is treated as zero. -MLlib's PIC implementation takes the following (hyper-)parameters: +`spark.mllib`'s PIC implementation takes the following (hyper-)parameters: * `k`: number of clusters * `maxIterations`: maximum number of power iterations @@ -323,7 +323,7 @@ MLlib's PIC implementation takes the following (hyper-)parameters: **Examples** -In the following, we show code snippets to demonstrate how to use PIC in MLlib. +In the following, we show code snippets to demonstrate how to use PIC in `spark.mllib`.
    @@ -493,7 +493,7 @@ checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. -All of MLlib's LDA models support: +All of `spark.mllib`'s LDA models support: * `describeTopics`: Returns topics as arrays of most important terms and term weights @@ -721,7 +721,7 @@ sameModel = LDAModel.load(sc, "myModelPath") ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, +updating them as new data arrive. `spark.mllib` provides support for streaming k-means clustering, with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7cd1b894e7cb5..1ebb4654aef12 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - MLlib -displayTitle: MLlib - Collaborative Filtering +title: Collaborative Filtering - spark.mllib +displayTitle: Collaborative Filtering - spark.mllib --- * Table of contents @@ -11,12 +11,12 @@ displayTitle: MLlib - Collaborative Filtering [Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-item association matrix. MLlib currently supports +missing entries of a user-item association matrix. `spark.mllib` currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -MLlib uses the [alternating least squares +`spark.mllib` uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) -algorithm to learn these latent factors. The implementation in MLlib has the +algorithm to learn these latent factors. The implementation in `spark.mllib` has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). @@ -34,7 +34,7 @@ The standard approach to matrix factorization based collaborative filtering trea the entries in the user-item matrix as *explicit* preferences given by the user to the item. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with such data is taken +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). Essentially instead of trying to model the matrix of ratings directly, this approach treats the data @@ -119,4 +119,4 @@ a dependency. ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for -[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). +[personalized movie recommendation with `spark.mllib`](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 3c0c0479674df..363dc7c13b306 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global title: Data Types - MLlib -displayTitle: MLlib - Data Types +displayTitle: Data Types - MLlib --- * Table of contents diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 77ce34e91af3c..a8612b6c84fe9 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - MLlib -displayTitle: MLlib - Decision Trees +title: Decision Trees - spark.mllib +displayTitle: Decision Trees - spark.mllib --- * Table of contents @@ -15,7 +15,7 @@ feature scaling, and are able to capture non-linearities and feature interaction algorithms such as random forests and boosting are among the top performers for classification and regression tasks. -MLlib supports decision trees for binary and multiclass classification and for regression, +`spark.mllib` supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index ac3526908a9f4..11d8e0bd1d23d 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - MLlib -displayTitle: MLlib - Dimensionality Reduction +title: Dimensionality Reduction - spark.mllib +displayTitle: Dimensionality Reduction - spark.mllib --- * Table of contents @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on the RowMatrix class. +`spark.mllib` provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -57,7 +57,7 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. ### SVD Example -MLlib provides SVD functionality to row-oriented matrices, provided in the +`spark.mllib` provides SVD functionality to row-oriented matrices, provided in the RowMatrix class.
    @@ -141,7 +141,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. +`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 50450e05d2abb..2416b6fa0aeb3 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles - spark.mllib +displayTitle: Ensembles - spark.mllib --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -33,9 +33,9 @@ Like decision trees, random forests handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports random forests for binary and multiclass classification and for regression, +`spark.mllib` supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -MLlib implements random forests using the existing [decision tree](mllib-decision-tree.html) +`spark.mllib` implements random forests using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. ### Basic algorithm @@ -155,9 +155,9 @@ Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports GBTs for binary classification and for regression, +`spark.mllib` supports GBTs for binary classification and for regression, using both continuous and categorical features. -MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. +`spark.mllib` implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. *Note*: GBTs do not yet support multiclass classification. For multiclass problems, please use [decision trees](mllib-decision-tree.html) or [Random Forests](mllib-ensembles.html#Random-Forest). @@ -171,7 +171,7 @@ The specific mechanism for re-labeling instances is defined by a loss function ( #### Losses -The table below lists the losses currently supported by GBTs in MLlib. +The table below lists the losses currently supported by GBTs in `spark.mllib`. Note that each loss is applicable to one of classification or regression, not both. Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 6924037b941f3..774826c2703f8 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -1,20 +1,20 @@ --- layout: global -title: Evaluation Metrics - MLlib -displayTitle: MLlib - Evaluation Metrics +title: Evaluation Metrics - spark.mllib +displayTitle: Evaluation Metrics - spark.mllib --- * Table of contents {:toc} -Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +`spark.mllib` comes with a number of machine learning algorithms that can be used to learn from and make predictions on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance -of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +of the model on some criteria, which depends on the application and its requirements. `spark.mllib` also provides a suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those -metrics that are currently available in Spark's MLlib are detailed in this section. +metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 5bee170c61fe9..7796bac697562 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction and Transformation - MLlib -displayTitle: MLlib - Feature Extraction and Transformation +title: Feature Extraction and Transformation - spark.mllib +displayTitle: Feature Extraction and Transformation - spark.mllib --- * Table of contents @@ -31,7 +31,7 @@ The TF-IDF measure is simply the product of TF and IDF: TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). \]` There are several variants on the definition of term frequency and document frequency. -In MLlib, we separate TF and IDF to make them flexible. +In `spark.mllib`, we separate TF and IDF to make them flexible. Our implementation of term frequency utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). @@ -44,7 +44,7 @@ To reduce the chance of collision, we can increase the target feature dimension, the number of buckets of the hash table. The default feature dimension is `$2^{20} = 1,048,576$`. -**Note:** MLlib doesn't provide tools for text segmentation. +**Note:** `spark.mllib` doesn't provide tools for text segmentation. We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and [scalanlp/chalk](https://github.com/scalanlp/chalk). @@ -86,7 +86,7 @@ val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} -MLlib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. @@ -134,7 +134,7 @@ idf = IDF().fit(tf) tfidf = idf.transform(tf) {% endhighlight %} -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index fe42896a05d8e..2c8a8f236163f 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - MLlib -displayTitle: MLlib - Frequent Pattern Mining +title: Frequent Pattern Mining - spark.mllib +displayTitle: Frequent Pattern Mining - spark.mllib --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the @@ -9,7 +9,7 @@ first steps to analyze a large-scale dataset, which has been an active research data mining for years. We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) for more information. -MLlib provides a parallel implementation of FP-growth, +`spark.mllib` provides a parallel implementation of FP-growth, a popular algorithm to mining frequent itemsets. ## FP-growth @@ -22,13 +22,13 @@ Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) al the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. -In MLlib, we implemented a parallel version of FP-growth called PFP, +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffices of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. -MLlib's FP-growth implementation takes the following (hyper-)parameters: +`spark.mllib`'s FP-growth implementation takes the following (hyper-)parameters: * `minSupport`: the minimum support for an itemset to be identified as frequent. For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. @@ -126,7 +126,7 @@ PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. -MLlib's PrefixSpan implementation takes the following parameters: +`spark.mllib`'s PrefixSpan implementation takes the following parameters: * `minSupport`: the minimum support required to be considered a frequent sequential pattern. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 3bc2b780601c2..7fef6b5c61f99 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,7 +66,7 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Overview: estimators, transformers and pipelines](ml-guide.html) * [Extracting, transforming and selecting features](ml-features.html) * [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 85f9226b43416..8ede4407d5843 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - MLlib -displayTitle: MLlib - Regression +title: Isotonic regression - spark.mllib +displayTitle: Regression - spark.mllib --- ## Isotonic regression @@ -23,7 +23,7 @@ Essentially isotonic regression is a [monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) best fitting the original data points. -MLlib supports a +`spark.mllib` supports a [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 132f8c354aa9c..20b35612cab95 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - MLlib -displayTitle: MLlib - Linear Methods +title: Linear Methods - spark.mllib +displayTitle: Linear Methods - spark.mllib --- * Table of contents @@ -41,7 +41,7 @@ the objective function is of the form Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several of MLlib's classification and regression algorithms fall into this category, +Several of `spark.mllib`'s classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: @@ -55,7 +55,7 @@ training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions The following table summarizes the loss functions and their gradients or sub-gradients for the -methods MLlib supports: +methods `spark.mllib` supports: @@ -83,7 +83,7 @@ methods MLlib supports: The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to encourage simple models and avoid overfitting. We support the following -regularizers in MLlib: +regularizers in `spark.mllib`:
    @@ -115,7 +115,10 @@ especially when the number of training examples is small. ### Optimization -Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. +Under the hood, linear methods use convex optimization methods to optimize the objective functions. +`spark.mllib` uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). +Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. +Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. ## Classification @@ -126,16 +129,16 @@ The most common classification type is categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). -MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +`spark.mllib` supports two linear methods for classification: linear Support Vector Machines (SVMs) and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. -For both methods, MLlib supports L1 and L2 regularized variants. +For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either $+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -207,7 +210,7 @@ val sameModel = SVMModel.load(sc, "myModelPath") The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -293,7 +296,7 @@ public class SVMClassifier { The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -375,7 +378,7 @@ Binary logistic regression can be generalized into train and predict multiclass classification problems. For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately regressed against the pivot outcome. -In MLlib, the first class $0$ is chosen as the "pivot" class. +In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. @@ -726,7 +729,7 @@ a dependency. ###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +updating the parameters of the model as new data arrives. `spark.mllib` currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -852,7 +855,7 @@ will get better! # Implementation (developer) -Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent +Behind the scene, `spark.mllib` implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the optimization section). All provided algorithms take as input a regularization parameter (`regParam`) along with various parameters associated with stochastic diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 774b85d1f773a..73e4fddf67fc0 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,7 +1,7 @@ --- layout: global -title: Old Migration Guides - MLlib -displayTitle: MLlib - Old Migration Guides +title: Old Migration Guides - spark.mllib +displayTitle: Old Migration Guides - spark.mllib description: MLlib migration guides from before Spark SPARK_VERSION_SHORT --- diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 60ac6c7e5bb1a..d0d594af6a4ad 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - MLlib -displayTitle: MLlib - Naive Bayes +title: Naive Bayes - spark.mllib +displayTitle: Naive Bayes - spark.mllib --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple @@ -12,7 +12,7 @@ distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. -MLlib supports [multinomial naive +`spark.mllib` supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index ad7bcd9bfd407..f90b66f8e2c44 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - MLlib -displayTitle: MLlib - Optimization +title: Optimization - spark.mllib +displayTitle: Optimization - spark.mllib --- * Table of contents @@ -87,7 +87,7 @@ in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that s step-size for SGD methods can often be delicate in practice and is a topic of active research. **Gradients.** -A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in +A table of (sub)gradients of the machine learning methods implemented in `spark.mllib`, is available in the classification and regression section. @@ -140,7 +140,7 @@ other first-order optimization. ### Choosing an Optimization Method -[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in MLlib support both SGD and L-BFGS. +[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in `spark.mllib` support both SGD and L-BFGS. Different optimization methods can have different convergence guarantees depending on the properties of the objective function, and we cannot cover the literature here. In general, when L-BFGS is available, we recommend using it instead of SGD since L-BFGS tends to converge faster (in fewer iterations). diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 615287125c032..b532ad907dfc5 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,21 +1,21 @@ --- layout: global -title: PMML model export - MLlib -displayTitle: MLlib - PMML model export +title: PMML model export - spark.mllib +displayTitle: PMML model export - spark.mllib --- * Table of contents {:toc} -## MLlib supported models +## `spark.mllib` supported models -MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). +`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). -The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. +The table below outlines the `spark.mllib` models that can be exported to PMML and their equivalent PMML model.
    - + diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de209f68e19ca..652d215fa8653 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -1,7 +1,7 @@ --- layout: global -title: Basic Statistics - MLlib -displayTitle: MLlib - Basic Statistics +title: Basic Statistics - spark.mllib +displayTitle: Basic Statistics - spark.mllib --- * Table of contents @@ -112,7 +112,7 @@ print(summary.numNonzeros()) ## Correlations -Calculating the correlation between two series of data is a common operation in Statistics. In MLlib +Calculating the correlation between two series of data is a common operation in Statistics. In `spark.mllib` we provide the flexibility to calculate pairwise correlations among many series. The supported correlation methods are currently Pearson's and Spearman's correlation. @@ -209,7 +209,7 @@ print(Statistics.corr(data, method="pearson")) ## Stratified sampling -Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, +Unlike the other statistics functions, which reside in `spark.mllib`, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified sampling, the keys can be thought of as a label and the value as a specific attribute. For example the key can be man or woman, or document ids, and the respective values can be the list of ages @@ -294,12 +294,12 @@ approxSample = data.sampleByKey(False, fractions); ## Hypothesis testing Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically -significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +significant, whether this result occurred by chance or not. `spark.mllib` currently supports Pearson's chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. -MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +`spark.mllib` also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests.
    @@ -438,7 +438,7 @@ for i, result in enumerate(featureTestResults):
    -Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +Additionally, `spark.mllib` provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test for equality of probability distributions. By providing the name of a theoretical distribution (currently solely supported for the normal distribution) and its parameters, or a function to calculate the cumulative distribution according to a given theoretical distribution, the user can @@ -522,7 +522,7 @@ print(testResult) # summary of the test including the p-value, test statistic, ### Streaming Significance Testing -MLlib provides online implementations of some tests to support use cases +`spark.mllib` provides online implementations of some tests to support use cases like A/B testing. These tests may be performed on a Spark Streaming `DStream[(Boolean,Double)]` where the first element of each tuple indicates control group (`false`) or treatment group (`true`) and the @@ -550,7 +550,7 @@ provides streaming hypothesis testing. ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +`spark.mllib` supports generating random RDDs with i.i.d. values drawn from a given distribution: uniform, standard normal, or Poisson.
    From 34a9bc1ef9a4642917d0190a61794411df576fe0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 10 Dec 2015 13:26:30 -0800 Subject: [PATCH 1111/1168] [SPARK-11563][CORE][REPL] Use RpcEnv to transfer REPL-generated classes. This avoids bringing up yet another HTTP server on the driver, and instead reuses the file server already managed by the driver's RpcEnv. As a bonus, the repl now inherits the security features of the network library. There's also a small change to create the directory for storing classes under the root temp dir for the application (instead of directly under java.io.tmpdir). Author: Marcelo Vanzin Closes #9923 from vanzin/SPARK-11563. --- .../org/apache/spark/HttpFileServer.scala | 5 ++ .../scala/org/apache/spark/HttpServer.scala | 26 ++++++---- .../scala/org/apache/spark/SparkContext.scala | 6 +++ .../org/apache/spark/executor/Executor.scala | 6 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 18 +++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 5 ++ .../spark/rpc/netty/NettyStreamManager.scala | 22 ++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 25 +++++++++- docs/configuration.md | 8 ---- docs/security.md | 8 ---- .../org/apache/spark/repl/SparkILoop.scala | 17 ++++--- .../org/apache/spark/repl/SparkIMain.scala | 28 ++--------- .../scala/org/apache/spark/repl/Main.scala | 23 ++++----- .../spark/repl/ExecutorClassLoader.scala | 36 +++++++------- .../spark/repl/ExecutorClassLoaderSuite.scala | 48 +++++++++++++++---- 15 files changed, 183 insertions(+), 98 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7cf7bc0dc6810..77d8ec9bb1607 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -71,6 +71,11 @@ private[spark] class HttpFileServer( serverUri + "/jars/" + file.getName } + def addDirectory(path: String, resourceBase: String): String = { + httpServer.addDirectory(path, resourceBase) + serverUri + path + } + def addFileToDir(file: File, dir: File) : String = { // Check whether the file is a directory. If it is, throw a more meaningful exception. // If we don't catch this, Guava throws a very confusing error message: diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 8de3a6c04df34..faa3ef3d7561d 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -23,10 +23,9 @@ import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} - import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} +import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils @@ -52,6 +51,11 @@ private[spark] class HttpServer( private var server: Server = null private var port: Int = requestedPort + private val servlets = { + val handler = new ServletContextHandler() + handler.setContextPath("/") + handler + } def start() { if (server != null) { @@ -65,6 +69,14 @@ private[spark] class HttpServer( } } + def addDirectory(contextPath: String, resourceBase: String): Unit = { + val holder = new ServletHolder() + holder.setInitParameter("resourceBase", resourceBase) + holder.setInitParameter("pathInfoOnly", "true") + holder.setServlet(new DefaultServlet()) + servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*") + } + /** * Actually start the HTTP server on the given port. * @@ -85,21 +97,17 @@ private[spark] class HttpServer( val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + addDirectory("/", resourceBase.getAbsolutePath) if (securityManager.isAuthenticationEnabled()) { logDebug("HttpServer is using security") val sh = setupSecurityHandler(securityManager) // make sure we go through security handler to get resources - sh.setHandler(handlerList) + sh.setHandler(servlets) server.setHandler(sh) } else { logDebug("HttpServer is not using security") - server.setHandler(handlerList) + server.setHandler(servlets) } server.start() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8a62b71c3fa68..194ecc0a0434e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -457,6 +457,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) + // If running the REPL, register the repl's output dir with the file server. + _conf.getOption("spark.repl.class.outputDir").foreach { path => + val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) + _conf.set("spark.repl.class.uri", replUri) + } + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) _statusTracker = new SparkStatusTracker(this) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7b68dfe5ad06e..552b644d13aaf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -364,9 +364,9 @@ private[spark] class Executor( val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], - classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, _userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3d7d281b0dd66..64a4a8bf7c5eb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -179,6 +179,24 @@ private[spark] trait RpcEnvFileServer { */ def addJar(file: File): String + /** + * Adds a local directory to be served via this file server. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectory(baseUri: String, path: File): String + + /** Validates and normalizes the base URI for directories. */ + protected def validateDirectoryUri(baseUri: String): String = { + val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/") + require(fixedBaseUri != "/files" && fixedBaseUri != "/jars", + "Directory URI cannot be /files nor /jars.") + fixedBaseUri + } + } private[spark] case class RpcEnvConfig( diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 94dbec593c315..9d098154f7190 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -273,6 +273,11 @@ private[akka] class AkkaFileServer( getFileServer().addJar(file) } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) + } + def shutdown(): Unit = { if (httpFileServer != null) { httpFileServer.stop() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index a2768b4252dcb..ecd96972455d0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -25,12 +25,22 @@ import org.apache.spark.rpc.RpcEnvFileServer /** * StreamManager implementation for serving files from a NettyRpcEnv. + * + * Three kinds of resources can be registered in this manager, all backed by actual files: + * + * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]]. + * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]]. + * - arbitrary directories; all files under the directory become available through the manager, + * respecting the directory's hierarchy. + * + * Only streaming (openStream) is supported. */ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) extends StreamManager with RpcEnvFileServer { private val files = new ConcurrentHashMap[String, File]() private val jars = new ConcurrentHashMap[String, File]() + private val dirs = new ConcurrentHashMap[String, File]() override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { throw new UnsupportedOperationException() @@ -41,7 +51,10 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) val file = ftype match { case "files" => files.get(fname) case "jars" => jars.get(fname) - case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + case other => + val dir = dirs.get(ftype) + require(dir != null, s"Invalid stream URI: $ftype not found.") + new File(dir, fname) } require(file != null && file.isFile(), s"File not found: $streamId") @@ -60,4 +73,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null, + s"URI '$fixedBaseUri' already registered.") + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6cc958a5f6bc8..a61d0479aacdb 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -734,9 +734,28 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + val fileUri = env.fileServer.addFile(file) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } val destDir = Utils.createTempDir() val sm = new SecurityManager(conf) @@ -745,7 +764,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), (empty, emptyUri), - (jar, jarUri)) + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) files.foreach { case (f, uri) => val destFile = new File(destDir, f.getName()) Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) @@ -753,7 +774,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } // Try to download files that do not exist. - Seq("files", "jars").foreach { root => + Seq("files", "jars", "dir1").foreach { root => intercept[Exception] { val uri = env.address.toSparkURL + s"/$root/doesNotExist" Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) diff --git a/docs/configuration.md b/docs/configuration.md index fd61ddc244f44..873a2d0b303cd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1053,14 +1053,6 @@ Apart from these, the following properties are also available, and may be useful to port + maxRetries.
    - - - - - diff --git a/docs/security.md b/docs/security.md index e1af221d446b0..0bfc791c5744e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -169,14 +169,6 @@ configure those ports. - - - - - - - - diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 304b1e8cdbed5..22749c4609345 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -253,7 +253,7 @@ class SparkILoop( case xs => xs find (_.name == cmd) } } - private var fallbackMode = false + private var fallbackMode = false private def toggleFallbackMode() { val old = fallbackMode @@ -261,9 +261,9 @@ class SparkILoop( System.setProperty("spark.repl.fallback", fallbackMode.toString) echo(s""" |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would + | If you have defined classes in the repl, it would |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` + |into issues it would be good to restart the repl and turn on `:fallback` |mode as first command. """.stripMargin) } @@ -350,7 +350,7 @@ class SparkILoop( shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. + |disable/enable advanced repl changes, these fix some issues but may introduce others. |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) @@ -1009,8 +1009,13 @@ class SparkILoop( val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) - .set("spark.repl.class.uri", intp.classServerUri) .setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1025,7 +1030,7 @@ class SparkILoop( val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") } catch { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 829b12269fd2b..7fcb423575d39 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -37,7 +37,7 @@ import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.spark.annotation.DeveloperApi @@ -96,10 +96,9 @@ import org.apache.spark.annotation.DeveloperApi private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - private lazy val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - Utils.createTempDir(rootDir) + private[repl] val outputDir = { + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + Utils.createTempDir(root = rootDir, namePrefix = "repl") } if (SPARK_DEBUG_REPL) { echo("Output directory: " + outputDir) @@ -114,8 +113,6 @@ import org.apache.spark.annotation.DeveloperApi private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - private val classServerPort = conf.getInt("spark.replClassServer.port", 0) - private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings private var printResults = true // whether to print result lines private var totalSilence = false // whether to print anything @@ -124,22 +121,6 @@ import org.apache.spark.annotation.DeveloperApi private var bindExceptions = true // whether to bind the lastException variable private var _executionWrapper = "" // code to be wrapped around all lines - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /** - * URI of the class server used to feed REPL compiled classes. - * - * @return The string representing the class server uri - */ - @DeveloperApi - def classServerUri = classServer.uri - /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -994,7 +975,6 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi def close() { reporter.flush() - classServer.stop() } /** diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 455a6b9a93aad..44650f25f7a18 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.SQLContext object Main extends Logging { val conf = new SparkConf() - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - val outputDir = Utils.createTempDir(rootDir) + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed - lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. @@ -45,7 +47,6 @@ object Main extends Logging { } def main(args: Array[String]) { - val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", @@ -57,11 +58,7 @@ object Main extends Logging { if (!hasErrors) { if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() interp.process(settings) // Repl starts and goes in loop of R.E.P.L - classServer.stop() Option(sparkContext).map(_.stop) } } @@ -82,9 +79,13 @@ object Main extends Logging { val conf = new SparkConf() .setMaster(getMaster) .setJars(jars) - .set("spark.repl.class.uri", classServer.uri) .setIfMissing("spark.app.name", "Spark shell") - logInfo("Spark class server started at " + classServer.uri) + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a8859fcd4584b..da8f0aa1e3360 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -19,6 +19,7 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels import scala.util.control.NonFatal @@ -38,7 +39,11 @@ import org.apache.spark.util.ParentClassLoader * This class loader delegates getting/finding resources to parent loader, * which makes sense until REPL never provide resource dynamically. */ -class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath @@ -48,13 +53,12 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 - // Hadoop FileSystem object for our URI, if it isn't using HTTP - var fileSystem: FileSystem = { - if (Set("http", "https", "ftp").contains(uri.getScheme)) { - null - } else { - FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) - } + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) } override def getResource(name: String): URL = { @@ -90,6 +94,11 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + Channels.newInputStream(channel) + } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) @@ -126,7 +135,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } - private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( + pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) if (fileSystem.exists(path)) { fileSystem.open(path) @@ -139,13 +149,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val pathInDirectory = name.replace('.', '/') + ".class" var inputStream: InputStream = null try { - inputStream = { - if (fileSystem != null) { - getClassFileInputStreamFromFileSystem(pathInDirectory) - } else { - getClassFileInputStreamFromHttpServer(pathInDirectory) - } - } + inputStream = fetchFn(pathInDirectory) val bytes = readAndTransformClass(name, inputStream) Some(defineClass(name, bytes, 0, bytes.length)) } catch { diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c1211f7596b9c..1360f09e7fa1f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -18,24 +18,29 @@ package org.apache.spark.repl import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} +import java.nio.channels.{FileChannel, ReadableByteChannel} import java.nio.charset.StandardCharsets +import java.nio.file.{Paths, StandardOpenOption} import java.util -import com.google.common.io.Files - import scala.concurrent.duration._ import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps +import com.google.common.io.Files import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.Matchers.anyString import org.mockito.Mockito._ import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils class ExecutorClassLoaderSuite @@ -78,7 +83,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -86,7 +91,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -94,7 +99,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -102,7 +107,7 @@ class ExecutorClassLoaderSuite test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } @@ -110,7 +115,7 @@ class ExecutorClassLoaderSuite test("resource from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val is = classLoader.getResourceAsStream(resourceName) assert(is != null, s"Resource $resourceName not found") @@ -120,7 +125,7 @@ class ExecutorClassLoaderSuite test("resources from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) assert(resources.hasMoreElements, s"Resource $resourceName not found") @@ -142,7 +147,7 @@ class ExecutorClassLoaderSuite SparkEnv.set(mockEnv) // Create an ExecutorClassLoader that's configured to load classes from the HTTP server val parentLoader = new URLClassLoader(Array.empty, null) - val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false) classLoader.httpUrlConnectionTimeoutMillis = 500 // Check that this class loader can actually load classes that exist val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() @@ -177,4 +182,27 @@ class ExecutorClassLoaderSuite failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) } + test("fetch classes using Spark's RpcEnv") { + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) + FileChannel.open(path, StandardOpenOption.READ) + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + } From 4983e4d4944f1274a2d57ecc6fb815d9d263a0c3 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 10 Dec 2015 14:21:15 -0800 Subject: [PATCH 1112/1168] [SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark Adding ability to define an initial state RDD for use with updateStateByKey PySpark. Added unit test and changed stateful_network_wordcount example to use initial RDD. Author: Bryan Cutler Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713. --- .../streaming/stateful_network_wordcount.py | 5 ++++- python/pyspark/streaming/dstream.py | 13 ++++++++++-- python/pyspark/streaming/tests.py | 20 +++++++++++++++++++ .../streaming/api/python/PythonDStream.scala | 14 +++++++++++-- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 16ef646b7c42e..f8bbc659c2ea7 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -44,13 +44,16 @@ ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") + # RDD with initial state (key, value) pairs + initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) + def updateFunc(new_values, last_sum): return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) running_counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ - .updateStateByKey(updateFunc) + .updateStateByKey(updateFunc, initialRDD=initialStateRDD) running_counts.pprint() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index acec850f02c2d..f61137cb88c47 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -568,7 +568,7 @@ def invReduceFunc(t, a, b): self._ssc._jduration(slideDuration)) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. @@ -579,6 +579,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +593,13 @@ def reduceFunc(t, a, b): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a2bfd79e1abcd..4949cd68e3212 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,26 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_update_state_by_key_initial_rdd(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + initial = [('k', [0, 1])] + initial = self.sc.parallelize(initial, 1) + + input = [[('k', i)] for i in range(2, 5)] + + def func(dstream): + return dstream.updateStateByKey(updater, initialRDD=initial) + + expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + def test_failed_func(self): # Test failure in # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 994309ddd0a3e..056248ccc7bcd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction, + initialRDD: Option[RDD[Array[Byte]]]) extends PythonDStream(parent, reduceFunc) { + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None) + + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction, + initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd)) + super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -274,7 +284,7 @@ private[python] class PythonStateDStream( val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - func(lastState, rdd, validTime) + func(lastState.orElse(initialRDD), rdd, validTime) } else { lastState } From 8ccbe0ed62d09ec2583bb2c962a322b85e168ea4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 10 Dec 2015 15:29:04 -0800 Subject: [PATCH 1113/1168] [SPARK-12251] Document and improve off-heap memory configurations This patch adds documentation for Spark configurations that affect off-heap memory and makes some naming and validation improvements for those configs. - Change `spark.memory.offHeapSize` to `spark.memory.offHeap.size`. This is fine because this configuration has not shipped in any Spark release yet (it's new in Spark 1.6). - Deprecated `spark.unsafe.offHeap` in favor of a new `spark.memory.offHeap.enabled` configuration. The motivation behind this change is to gather all memory-related configurations under the same prefix. - Add a check which prevents users from setting `spark.memory.offHeap.enabled=true` when `spark.memory.offHeap.size == 0`. After SPARK-11389 (#9344), which was committed in Spark 1.6, Spark enforces a hard limit on the amount of off-heap memory that it will allocate to tasks. As a result, enabling off-heap execution memory without setting `spark.memory.offHeap.size` will lead to immediate OOMs. The new configuration validation makes this scenario easier to diagnose, helping to avoid user confusion. - Document these configurations on the configuration page. Author: Josh Rosen Closes #10237 from JoshRosen/SPARK-12251. --- .../scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/memory/MemoryManager.scala | 10 +++++++-- .../spark/memory/TaskMemoryManagerSuite.java | 21 +++++++++++++++---- .../sort/PackedRecordPointerSuite.java | 6 ++++-- .../sort/ShuffleInMemorySorterSuite.java | 4 ++-- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 ++-- .../sort/UnsafeExternalSorterSuite.java | 2 +- .../sort/UnsafeInMemorySorterSuite.java | 4 ++-- .../memory/StaticMemoryManagerSuite.scala | 2 +- .../memory/UnifiedMemoryManagerSuite.scala | 2 +- docs/configuration.md | 16 ++++++++++++++ .../sql/execution/joins/HashedRelation.scala | 6 +++++- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../UnsafeKVExternalSorterSuite.scala | 2 +- 15 files changed, 65 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 19633a3ce6a02..d3384fb297732 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -597,7 +597,9 @@ private[spark] object SparkConf extends Logging { "spark.streaming.fileStream.minRememberDuration" -> Seq( AlternateConfig("spark.streaming.minRememberDuration", "1.5")), "spark.yarn.max.executor.failures" -> Seq( - AlternateConfig("spark.yarn.max.worker.failures", "1.5")) + AlternateConfig("spark.yarn.max.worker.failures", "1.5")), + "spark.memory.offHeap.enabled" -> Seq( + AlternateConfig("spark.unsafe.offHeap", "1.6")) ) /** diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ae9e1ac0e246b..e707e27d96b50 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -50,7 +50,7 @@ private[spark] abstract class MemoryManager( storageMemoryPool.incrementPoolSize(storageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeapSize", 0)) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeap.size", 0)) /** * Total available memory for storage, in bytes. This amount can vary over time, depending on @@ -182,7 +182,13 @@ private[spark] abstract class MemoryManager( * sun.misc.Unsafe. */ final val tungstenMemoryMode: MemoryMode = { - if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP + if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { + require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } } /** diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 711eed0193bc0..776a2997cf91f 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -29,7 +29,7 @@ public class TaskMemoryManagerSuite { public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), + new SparkConf().set("spark.memory.offHeap.enabled", "false"), Long.MAX_VALUE, Long.MAX_VALUE, 1), @@ -41,8 +41,10 @@ public void leakedPageMemoryIsDetected() { @Test public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. This test exercises that corner-case: @@ -55,7 +57,7 @@ public void encodePageNumberAndOffsetOffHeap() { @Test public void encodePageNumberAndOffsetOnHeap() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); @@ -104,4 +106,15 @@ public void cooperativeSpilling() { assert(manager.cleanUpAllAllocatedMemory() == 0); } + @Test + public void offHeapConfigurationBackwardsCompatibility() { + // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which + // was deprecated in Spark 1.6 and replaced by `spark.memory.offHeap.enabled` (see SPARK-12251). + final SparkConf conf = new SparkConf() + .set("spark.unsafe.offHeap", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + assert(manager.tungstenMemoryMode == MemoryMode.OFF_HEAP); + } + } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 9a43f1f3a9235..fe5abc5c23049 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -35,7 +35,7 @@ public class PackedRecordPointerSuite { @Test public void heap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); @@ -54,7 +54,9 @@ public void heap() throws IOException { @Test public void offHeap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index faa5a863ee630..0328e63e45439 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -34,7 +34,7 @@ public class ShuffleInMemorySorterSuite { final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); @@ -64,7 +64,7 @@ public void testBasicSorting() throws Exception { "Lychee", "Mango" }; - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index bc85918c59aab..5fe64bde3604a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -108,7 +108,7 @@ public void setUp() throws IOException { spillFilesCreated.clear(); conf = new SparkConf() .set("spark.buffer.pageSize", "1m") - .set("spark.unsafe.offHeap", "false"); + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 8724a34988421..702ba5469b8b4 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -85,8 +85,8 @@ public void setup() { memoryManager = new TestMemoryManager( new SparkConf() - .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeapSize", "256mb")); + .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeap.size", "256mb")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a1c9f6fab8e65..e0ee281e98b71 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -58,7 +58,7 @@ public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index a203a09648ac0..93efd033eb940 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -46,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, @@ -71,7 +71,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Mango" }; final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 272253bc94e91..68cf26fc3ed5d 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -47,7 +47,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { conf.clone .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString), + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 71221deeb4c28..e21a028b7faec 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -42,7 +42,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } diff --git a/docs/configuration.md b/docs/configuration.md index 873a2d0b303cd..55cf4b2dac5f5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -738,6 +738,22 @@ Apart from these, the following properties are also available, and may be useful this description. + + + + + + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index aebfea5832402..8c7099ab5a34d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -334,7 +334,11 @@ private[joins] final class UnsafeHashedRelation( // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7ceaee38d131b..5a8406789ab81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -61,7 +61,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 7b80963ec8708..29027a664b4b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, From fb5b0f84fbb2a6bf0b480720bc61029c1556f053 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 10 Dec 2015 15:30:08 -0800 Subject: [PATCH 1114/1168] [SPARK-12155][SPARK-12253] Fix executor OOM in unified memory management **Problem.** In unified memory management, acquiring execution memory may lead to eviction of storage memory. However, the space freed from evicting cached blocks is distributed among all active tasks. Thus, an incorrect upper bound on the execution memory per task can cause the acquisition to fail, leading to OOM's and premature spills. **Example.** Suppose total memory is 1000B, cached blocks occupy 900B, `spark.memory.storageFraction` is 0.4, and there are two active tasks. In this case, the cap on task execution memory is 100B / 2 = 50B. If task A tries to acquire 200B, it will evict 100B of storage but can only acquire 50B because of the incorrect cap. For another example, see this [regression test](https://github.com/andrewor14/spark/blob/fix-oom/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala#L233) that I stole from JoshRosen. **Solution.** Fix the cap on task execution memory. It should take into account the space that could have been freed by storage in addition to the current amount of memory available to execution. In the example above, the correct cap should have been 600B / 2 = 300B. This patch also guards against the race condition (SPARK-12253): (1) Existing tasks collectively occupy all execution memory (2) New task comes in and blocks while existing tasks spill (3) After tasks finish spilling, another task jumps in and puts in a large block, stealing the freed memory (4) New task still cannot acquire memory and goes back to sleep Author: Andrew Or Closes #10240 from andrewor14/fix-oom. --- .../spark/memory/ExecutionMemoryPool.scala | 57 +++++++++++++------ .../spark/memory/UnifiedMemoryManager.scala | 57 ++++++++++++++----- .../org/apache/spark/scheduler/Task.scala | 6 ++ .../memory/UnifiedMemoryManagerSuite.scala | 25 ++++++++ 4 files changed, 114 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 9023e1ac012b7..dbb0ad8d5c673 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -70,11 +70,28 @@ private[memory] class ExecutionMemoryPool( * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in + * one parameter (Long) that represents the desired amount of memory by + * which this pool should be expanded. + * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool + * at this given moment. This is not a field because the max pool + * size is variable in certain cases. For instance, in unified + * memory management, the execution pool can be expanded by evicting + * cached blocks, thereby shrinking the storage pool. + * * @return the number of bytes granted to the task. */ - def acquireMemory(numBytes: Long, taskAttemptId: Long): Long = lock.synchronized { + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long, + maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit, + computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized { assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + // TODO: clean up this clunky method signature + // Add this task to the taskMemory map just so we can keep an accurate count of the number // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` if (!memoryForTask.contains(taskAttemptId)) { @@ -91,25 +108,31 @@ private[memory] class ExecutionMemoryPool( val numActiveTasks = memoryForTask.keys.size val curMem = memoryForTask(taskAttemptId) - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (poolSize / numActiveTasks) - curMem)) + // In every iteration of this loop, we should first try to reclaim any borrowed execution + // space from storage. This is necessary because of the potential race condition where new + // storage blocks may steal the free execution memory that this task was waiting for. + maybeGrowPool(numBytes - memoryFree) + + // Maximum size the pool would have after potentially growing the pool. + // This is used to compute the upper bound of how much memory each task can occupy. This + // must take into account potential free memory as well as the amount this pool currently + // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management, + // we did not take into account space that could have been freed by evicting cached blocks. + val maxPoolSize = computeMaxPoolSize() + val maxMemoryPerTask = maxPoolSize / numActiveTasks + val minMemoryPerTask = poolSize / (2 * numActiveTasks) + + // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks + val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem)) // Only give it as much memory as is free, which might be none if it reached 1 / numTasks val toGrant = math.min(maxToGrant, memoryFree) - if (curMem < poolSize / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (memoryFree >= math.min(maxToGrant, poolSize / (2 * numActiveTasks) - curMem)) { - memoryForTask(taskAttemptId) += toGrant - return toGrant - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") - lock.wait() - } + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) { + logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() } else { memoryForTask(taskAttemptId) += toGrant return toGrant diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0b9f6a9dc0525..829f054dba0e9 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -81,22 +81,51 @@ private[spark] class UnifiedMemoryManager private[memory] ( assert(numBytes >= 0) memoryMode match { case MemoryMode.ON_HEAP => - if (numBytes > onHeapExecutionMemoryPool.memoryFree) { - val extraMemoryNeeded = numBytes - onHeapExecutionMemoryPool.memoryFree - // There is not enough free memory in the execution pool, so try to reclaim memory from - // storage. We can reclaim any free memory from the storage pool. If the storage pool - // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim - // the memory that storage has borrowed from execution. - val memoryReclaimableFromStorage = - math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) - if (memoryReclaimableFromStorage > 0) { - // Only reclaim as much space as is necessary and available: - val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( - math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) - onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + + /** + * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool. + * + * When acquiring memory for a task, the execution pool may need to make multiple + * attempts. Each attempt must be able to evict storage in case another task jumps in + * and caches a large block between the attempts. This is called once per attempt. + */ + def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { + if (extraMemoryNeeded > 0) { + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } } } - onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + + /** + * The size the execution pool would have after evicting storage memory. + * + * The execution memory pool divides this quantity among the active tasks evenly to cap + * the execution memory allocation for each task. It is important to keep this greater + * than the execution pool size, which doesn't take into account potential memory that + * could be freed by evicting storage. Otherwise we may hit SPARK-12155. + * + * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness + * in execution memory allocation across tasks, Otherwise, a task may occupy more than + * its fair share of execution memory, mistakenly thinking that other tasks can acquire + * the portion of storage memory that cannot be evicted. + */ + def computeMaxExecutionPoolSize(): Long = { + maxMemory - math.min(storageMemoryUsed, storageRegionSize) + } + + onHeapExecutionMemoryPool.acquireMemory( + numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) + case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d4bc3a5c900f7..9f27eed626be3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -92,6 +92,12 @@ private[spark] abstract class Task[T]( Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } } } finally { TaskContext.unset() diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e21a028b7faec..6cc48597d38f9 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -230,4 +230,29 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(exception.getMessage.contains("larger heap size")) } + test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + val ms = makeMemoryStore(mm) + assert(mm.maxMemory === 1000) + // Have two tasks each acquire some execution memory so that the memory pool registers that + // there are two active tasks: + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assert(mm.acquireExecutionMemory(100L, 1, MemoryMode.ON_HEAP) === 100L) + // Fill up all of the remaining memory with storage. + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 800) + assert(mm.executionMemoryUsed === 200) + // A task should still be able to allocate 100 bytes execution memory by evicting blocks + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.executionMemoryUsed === 300) + assert(mm.storageMemoryUsed === 700) + assert(evictedBlocks.nonEmpty) + } + } From 7413ec87f71ed4cc12661a74d3855e22d07cb767 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 10 Dec 2015 15:31:46 -0800 Subject: [PATCH 1115/1168] [STREAMING][DOC][MINOR] Update the description of direct Kafka stream doc With the merge of [SPARK-8337](https://issues.apache.org/jira/browse/SPARK-8337), now the Python API has the same functionalities compared to Scala/Java, so here changing the description to make it more precise. zsxwing tdas , please review, thanks a lot. Author: jerryshao Closes #10246 from jerryshao/direct-kafka-doc-update. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index b00351b2fbcc0..5be73c42560f5 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -74,7 +74,7 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. This approach has the following advantages over the receiver-based approach (i.e. Approach 1). From c4156279bcdec165bad9fcec3360a71102f2f29e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 10 Dec 2015 17:22:18 -0800 Subject: [PATCH 1116/1168] [SPARK-12258][SQL] passing null into ScalaUDF Check nullability and passing them into ScalaUDF. Closes #10249 Author: Davies Liu Closes #10259 from davies/udf_null. --- .../apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 7 +++++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 03b89221ef2d3..5deb2f81d1738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,8 +1029,11 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zip(evals).map { - case (converter, eval) => s"$converter.apply(${eval.value})" + val funcArguments = converterTerms.zipWithIndex.map { + case (converter, i) => + val eval = evals(i) + val dt = children(i).dataType + s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" }.mkString(",") val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 605a6549dd686..8887dc68a50e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1138,14 +1138,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-11725: correctly handle null inputs for ScalaUDF") { - val df = Seq( + val df = sparkContext.parallelize(Seq( new java.lang.Integer(22) -> "John", - null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") + // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) null else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else i * 2 } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From eb564dbb33b4a42f3996f36a59964410bdb35800 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 02:35:53 -0500 Subject: [PATCH 1117/1168] [SPARK-10991][ML] logistic regression training summary handle empty prediction col LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related too https://issues.apache.org/jira/browse/SPARK-9718 ) Author: Holden Karau Author: Holden Karau Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col. --- .../classification/LogisticRegression.scala | 20 +++++++++++++++++-- .../LogisticRegressionSuite.scala | 11 ++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 19cc323d5073f..486043e8d9741 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.unpersist() val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] ( new NullPointerException()) } + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. + */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + private[classification] def setSummary( summary: LogisticRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a9a6ff8a783d5..1087afb0cdf79 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -99,6 +99,17 @@ class LogisticRegressionSuite assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default From c50c7d0457a42df0e2c13c522b92f68fe56ba875 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 11 Dec 2015 11:15:53 -0800 Subject: [PATCH 1118/1168] [SPARK-12258] [SQL] passing null into ScalaUDF (follow-up) This is a follow-up PR for #10259 Author: Davies Liu Closes #10266 from davies/null_udf2. --- .../sql/catalyst/expressions/ScalaUDF.scala | 31 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 8 +++-- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5deb2f81d1738..85faa19bbf5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,24 +1029,27 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zipWithIndex.map { - case (converter, i) => - val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" - }.mkString(",") - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip - evalCode + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - Boolean ${ev.isNull}; + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + s""" + $evalCode + ${converters.mkString("\n")} $callFunc - ${ev.value} = $resultTerm; - ${ev.isNull} = $resultTerm == null; + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8887dc68a50e7..5353fefaf4b84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) -10 else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else null } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From a45d1febfb1e480ba0227bdeca528e9567714d33 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 11:47:35 -0800 Subject: [PATCH 1119/1168] [SPARK-12146][SPARKR] SparkR jsonFile should support multiple input files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ```jsonFile``` should support multiple input files, such as: ```R jsonFile(sqlContext, c(“path1”, “path2”)) # character vector as arguments jsonFile(sqlContext, “path1,path2”) ``` * Meanwhile, ```jsonFile``` has been deprecated by Spark SQL and will be removed at Spark 2.0. So we mark ```jsonFile``` deprecated and use ```read.json``` at SparkR side. * Replace all ```jsonFile``` with ```read.json``` at test_sparkSQL.R, but still keep jsonFile test case. * If this PR is accepted, we should also make almost the same change for ```parquetFile```. cc felixcheung sun-rui shivaram Author: Yanbo Liang Closes #10145 from yanboliang/spark-12146. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 102 +++++++++--------- R/pkg/R/SQLContext.R | 29 ++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 120 ++++++++++++---------- examples/src/main/r/dataframe.R | 2 +- 5 files changed, 138 insertions(+), 116 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba64bc59edee5..cab39d68c3f52 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -267,6 +267,7 @@ export("as.DataFrame", "createExternalTable", "dropTempTable", "jsonFile", + "read.json", "loadDF", "parquetFile", "read.df", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c4a2585e291..975b058c0aaf1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -24,14 +24,14 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame #' @description DataFrames can be created using functions like \link{createDataFrame}, -#' \link{jsonFile}, \link{table} etc. +#' \link{read.json}, \link{table} etc. #' @family DataFrame functions #' @rdname DataFrame #' @docType class #' #' @slot env An R environment that stores bookkeeping states of the DataFrame #' @slot sdf A Java object reference to the backing Scala DataFrame -#' @seealso \link{createDataFrame}, \link{jsonFile}, \link{table} +#' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export #' @examples @@ -77,7 +77,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -102,7 +102,7 @@ setMethod("printSchema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -126,7 +126,7 @@ setMethod("schema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -157,7 +157,7 @@ setMethod("explain", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -182,7 +182,7 @@ setMethod("isLocal", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -207,7 +207,7 @@ setMethod("showDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -234,7 +234,7 @@ setMethod("show", "DataFrame", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -261,7 +261,7 @@ setMethod("dtypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' columns(df) #' colnames(df) #'} @@ -376,7 +376,7 @@ setMethod("coltypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' coltypes(df) <- c("character", "integer") #' coltypes(df) <- c(NA, "numeric") #'} @@ -423,7 +423,7 @@ setMethod("coltypes<-", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "json_df") #' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} @@ -476,7 +476,7 @@ setMethod("insertInto", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -504,7 +504,7 @@ setMethod("cache", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -532,7 +532,7 @@ setMethod("persist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -560,7 +560,7 @@ setMethod("unpersist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -585,7 +585,7 @@ setMethod("repartition", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newRDD <- toJSON(df) #'} setMethod("toJSON", @@ -613,7 +613,7 @@ setMethod("toJSON", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -637,7 +637,7 @@ setMethod("saveAsParquetFile", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -672,7 +672,7 @@ setMethod("unique", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} @@ -711,7 +711,7 @@ setMethod("sample_frac", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' count(df) #' } setMethod("count", @@ -741,7 +741,7 @@ setMethod("nrow", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' ncol(df) #' } setMethod("ncol", @@ -762,7 +762,7 @@ setMethod("ncol", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dim(df) #' } setMethod("dim", @@ -786,7 +786,7 @@ setMethod("dim", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -858,7 +858,7 @@ setMethod("collect", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -879,7 +879,7 @@ setMethod("limit", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -908,7 +908,7 @@ setMethod("take", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' head(df) #' } setMethod("head", @@ -931,7 +931,7 @@ setMethod("head", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' first(df) #' } setMethod("first", @@ -952,7 +952,7 @@ setMethod("first", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' rdd <- toRDD(df) #'} setMethod("toRDD", @@ -1298,7 +1298,7 @@ setMethod("select", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -1327,7 +1327,7 @@ setMethod("selectExpr", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -1352,7 +1352,7 @@ setMethod("withColumn", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) @@ -1402,7 +1402,7 @@ setMethod("transform", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1427,7 +1427,7 @@ setMethod("withColumnRenamed", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1471,7 +1471,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) @@ -1547,7 +1547,7 @@ setMethod("orderBy", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1591,8 +1591,8 @@ setMethod("where", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1648,8 +1648,8 @@ setMethod("join", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' merge(df1, df2) # Performs a Cartesian #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) @@ -1781,8 +1781,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1824,8 +1824,8 @@ setMethod("rbind", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1851,8 +1851,8 @@ setMethod("intersect", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1892,7 +1892,7 @@ setMethod("except", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } @@ -1957,7 +1957,7 @@ setMethod("saveDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1998,7 +1998,7 @@ setMethod("saveAsTable", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -2054,7 +2054,7 @@ setMethod("summary", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' dropna(df) #' } setMethod("dropna", @@ -2108,7 +2108,7 @@ setMethod("na.omit", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f678c70a7a77c..9243d70e66f75 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -208,24 +208,33 @@ setMethod("toDF", signature(x = "RDD"), #' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.json +#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } - -jsonFile <- function(sqlContext, path) { +read.json <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- suppressWarnings(normalizePath(path)) - # Convert a string vector of paths to a string containing comma separated paths - path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlContext, "jsonFile", path) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +#' @rdname read.json +#' @name jsonFile +#' @export +jsonFile <- function(sqlContext, path) { + .Deprecated("read.json") + read.json(sqlContext, path) +} + #' JSON RDD #' @@ -299,7 +308,7 @@ parquetFile <- function(sqlContext, ...) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } @@ -323,7 +332,7 @@ sql <- function(sqlContext, sqlQuery) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- table(sqlContext, "table") #' } @@ -396,7 +405,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' cacheTable(sqlContext, "table") #' } @@ -418,7 +427,7 @@ cacheTable <- function(sqlContext, tableName) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' uncacheTable(sqlContext, "table") #' } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2051784427be3..ed9b2c9d4d16c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -330,7 +330,7 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType - df <- jsonFile(sqlContext, complexTypeJsonPath) + df <- read.json(sqlContext, complexTypeJsonPath) ldf <- collect(df) expect_equal(nrow(ldf), 3) @@ -357,7 +357,7 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) # StructType - df <- jsonFile(sqlContext, mapTypeJsonPath) + df <- read.json(sqlContext, mapTypeJsonPath) expect_equal(dtypes(df), list(c("info", "struct"), c("name", "string"))) ldf <- collect(df) @@ -371,10 +371,22 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) }) -test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) +test_that("read.json()/jsonFile() on a local file returns a DataFrame", { + df <- read.json(sqlContext, jsonPath) expect_is(df, "DataFrame") expect_equal(count(df), 3) + # read.json()/jsonFile() works with multiple input paths + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") + write.df(df, jsonPath2, "json", mode="overwrite") + jsonDF1 <- read.json(sqlContext, c(jsonPath, jsonPath2)) + expect_is(jsonDF1, "DataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath, jsonPath2))) + expect_is(jsonDF2, "DataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) }) test_that("jsonRDD() on a RDD with json string", { @@ -391,7 +403,7 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") cacheTable(sqlContext, "table1") uncacheTable(sqlContext, "table1") @@ -400,7 +412,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("test tableNames and tables", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) @@ -409,7 +421,7 @@ test_that("test tableNames and tables", { }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "DataFrame") @@ -445,7 +457,7 @@ test_that("insertInto() on a registered table", { }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") expect_is(tabledf, "DataFrame") @@ -458,14 +470,14 @@ test_that("table() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -487,7 +499,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -505,7 +517,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -516,7 +528,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -528,7 +540,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(names(rdf)[1], "age") @@ -550,14 +562,14 @@ test_that("collect() returns a data.frame", { }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfLimited <- limit(df, 2) expect_is(dfLimited, "DataFrame") expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(nrow(collect(df)), nrow(take(df, 10))) expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) @@ -584,7 +596,7 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -601,7 +613,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -620,7 +632,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testSchema <- schema(df) expect_equal(length(testSchema$fields()), 2) expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") @@ -641,7 +653,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("names() colnames() set the column names", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) names(df) <- c("col1", "col2") expect_equal(colnames(df)[2], "col2") @@ -661,7 +673,7 @@ test_that("names() colnames() set the column names", { }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testHead <- head(df) expect_equal(nrow(testHead), 3) expect_equal(ncol(testHead), 2) @@ -694,7 +706,7 @@ test_that("distinct() and unique on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) uniques <- distinct(df) expect_is(uniques, "DataFrame") expect_equal(count(uniques), 3) @@ -705,7 +717,7 @@ test_that("distinct() and unique on DataFrames", { }) test_that("sample on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") @@ -721,7 +733,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + df <- select(read.json(sqlContext, jsonPath), "name", "age") expect_is(df$name, "Column") expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") @@ -747,7 +759,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- select(df, "name") expect_equal(columns(df1), c("name")) expect_equal(count(df1), 3) @@ -770,8 +782,8 @@ test_that("select with column", { }) test_that("subsetting", { - # jsonFile returns columns in random order - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + # read.json returns columns in random order + df <- select(read.json(sqlContext, jsonPath), "name", "age") filtered <- df[df$age > 20,] expect_equal(count(filtered), 1) expect_equal(columns(filtered), c("name", "age")) @@ -808,7 +820,7 @@ test_that("subsetting", { }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -819,12 +831,12 @@ test_that("selectExpr() on a DataFrame", { }) test_that("expr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -915,7 +927,7 @@ test_that("column functions", { expect_equal(class(rank())[[1]], "Column") expect_equal(rank(1:3), as.numeric(c(1:3))) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) expect_equal(collect(df2)[[2, 1]], TRUE) expect_equal(collect(df2)[[2, 2]], FALSE) @@ -983,7 +995,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -1004,7 +1016,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -1100,7 +1112,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { }) test_that("group by, agg functions", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -1145,7 +1157,7 @@ test_that("group by, agg functions", { "{\"name\":\"ID2\", \"value\": \"-3\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name") + gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) @@ -1162,7 +1174,7 @@ test_that("group by, agg functions", { "{\"name\":\"Justin\", \"age\":1}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df8 <- jsonFile(sqlContext, jsonPath3) + df8 <- read.json(sqlContext, jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) @@ -1181,7 +1193,7 @@ test_that("group by, agg functions", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1,2], "Michael") @@ -1207,7 +1219,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) filtered <- filter(df, "age > 20") expect_equal(count(filtered), 1) expect_equal(collect(filtered)$name, "Andy") @@ -1230,7 +1242,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() and merge() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -1238,7 +1250,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlContext, jsonPath2) + df2 <- read.json(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -1313,14 +1325,14 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df3 <- jsonFile(sqlContext, jsonPath3) + df3 <- read.json(sqlContext, jsonPath3) expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") @@ -1328,7 +1340,7 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) s <- capture.output(showDF(df)) expected <- paste("+----+-------+\n", "| age| name|\n", @@ -1341,12 +1353,12 @@ test_that("showDF()", { }) test_that("isLocal()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", @@ -1383,7 +1395,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1395,7 +1407,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate(), transform(), rename() and names()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1425,7 +1437,7 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("write.df() on DataFrame and works with read.parquet", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- read.parquet(sqlContext, parquetPath) expect_is(parquetDF, "DataFrame") @@ -1433,7 +1445,7 @@ test_that("write.df() on DataFrame and works with read.parquet", { }) test_that("read.parquet()/parquetFile() works with multiple input paths", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") @@ -1452,7 +1464,7 @@ test_that("read.parquet()/parquetFile() works with multiple input paths", { }) test_that("describe() and summarize() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") @@ -1470,7 +1482,7 @@ test_that("describe() and summarize() on a DataFrame", { }) test_that("dropna() and na.omit() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # drop with columns @@ -1556,7 +1568,7 @@ test_that("dropna() and na.omit() on a DataFrame", { }) test_that("fillna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # fill with value @@ -1665,7 +1677,7 @@ test_that("Method as.data.frame as a synonym for collect()", { }) test_that("attach() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_error(age) attach(df) expect_is(age, "DataFrame") @@ -1713,7 +1725,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") - df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 53b817144f6ac..62f60e57eebe6 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -35,7 +35,7 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- jsonFile(sqlContext, path) +peopleDF <- read.json(sqlContext, path) printSchema(peopleDF) # Register this DataFrame as a table. From 3bff1d24cc23a0614c63d1fd12659247c9f37343 Mon Sep 17 00:00:00 2001 From: anabranch Date: Fri, 11 Dec 2015 12:55:56 -0800 Subject: [PATCH 1120/1168] [SPARK-11964][DOCS][ML] Add in Pipeline Import/Export Documentation Adding in Pipeline Import and Export Documentation. Author: anabranch Author: Bill Chambers Closes #10179 from anabranch/master. --- docs/ml-guide.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5c96c2b7d5cc9..44a316a07dfef 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -192,6 +192,10 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +## Saving and Loading Pipelines + +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. + # Code examples This section gives code examples illustrating the functionality discussed above. @@ -455,6 +459,15 @@ val pipeline = new Pipeline() // Fit the pipeline to training documents. val model = pipeline.fit(training) +// now we can optionally save the fitted pipeline to disk +model.save("/tmp/spark-logistic-regression-model") + +// we can also save this unfit pipeline to disk +pipeline.save("/tmp/unfit-lr-model") + +// and load it back in during production +val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") + // Prepare test documents, which are unlabeled (id, text) tuples. val test = sqlContext.createDataFrame(Seq( (4L, "spark i j k"), From a2376c367f5122fdd4b8b32b2f4a65acc5b87dc5 Mon Sep 17 00:00:00 2001 From: proflin Date: Fri, 11 Dec 2015 13:50:36 -0800 Subject: [PATCH 1121/1168] [SPARK-12273][STREAMING] Make Spark Streaming web UI list Receivers in order Currently the Streaming web UI does NOT list Receivers in order; however, it seems more convenient for the users if Receivers are listed in order. ![spark-12273](https://cloud.githubusercontent.com/assets/15843379/11736602/0bb7f7a8-a00b-11e5-8e86-96ba9297fb12.png) Author: proflin Closes #10264 from proflin/Spark-12273. --- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4588b2163cd44..88a4483e8068f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,8 +392,9 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, eventRates) => + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From 24352b78ed7927afe1b06a7a034da922ff4256fa Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Fri, 11 Dec 2015 14:21:33 -0800 Subject: [PATCH 1122/1168] [SPARK-11497][MLLIB][PYTHON] PySpark RowMatrix Constructor Has Type Erasure Issue As noted in PR #9441, implementing `tallSkinnyQR` uncovered a bug with our PySpark `RowMatrix` constructor. As discussed on the dev list [here](http://apache-spark-developers-list.1001551.n3.nabble.com/K-Means-And-Class-Tags-td10038.html), there appears to be an issue with type erasure with RDDs coming from Java, and by extension from PySpark. Although we are attempting to construct a `RowMatrix` from an `RDD[Vector]` in [PythonMLlibAPI](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala#L1115), the `Vector` type is erased, resulting in an `RDD[Object]`. Thus, when calling Scala's `tallSkinnyQR` from PySpark, we get a Java `ClassCastException` in which an `Object` cannot be cast to a Spark `Vector`. As noted in the aforementioned dev list thread, this issue was also encountered with `DecisionTrees`, and the fix involved an explicit `retag` of the RDD with a `Vector` type. `IndexedRowMatrix` and `CoordinateMatrix` do not appear to have this issue likely due to their related helper functions in `PythonMLlibAPI` creating the RDDs explicitly from DataFrames with pattern matching, thus preserving the types. This PR currently contains that retagging fix applied to the `createRowMatrix` helper function in `PythonMLlibAPI`. This PR blocks #9441, so once this is merged, the other can be rebased. cc holdenk Author: Mike Dusenberry Closes #9458 from dusenberrymw/SPARK-11497_PySpark_RowMatrix_Constructor_Has_Type_Erasure_Issue. --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2aa6aec0b4347..8d546e3d6099b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1143,7 +1143,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Wrapper around RowMatrix constructor. */ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { - new RowMatrix(rows.rdd, numRows, numCols) + new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) } /** From 74bc85862c1e274bb1751e7ea0067e260c1a8744 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 11 Dec 2015 15:43:00 -0800 Subject: [PATCH 1123/1168] [SPARK-12217][ML] Document invalid handling for StringIndexer Added a paragraph regarding StringIndexer#setHandleInvalid to the ml-features documentation. I wonder if I should also add a snippet to the code example, input welcome. Author: BenFradet Closes #10257 from BenFradet/SPARK-12217. --- docs/ml-features.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 6494fed0a01e5..8b00cc652dc7a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -459,6 +459,42 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. +Additionaly, there are two strategies regarding how `StringIndexer` will handle +unseen labels when you have fit a `StringIndexer` on one dataset and then use it +to transform another: + +- throw an exception (which is the default) +- skip the row containing the unseen label entirely + +**Examples** + +Let's go back to our previous example but this time reuse our previously defined +`StringIndexer` on the following dataset: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | d +~~~~ + +If you've not set how `StringIndexer` handles unseen labels or set it to +"error", an exception will be thrown. +However, if you had called `setHandleInvalid("skip")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 +~~~~ + +Notice that the row containing "d" does not appear. +
    From 9f25f85a37374a54c65ff4c1a73e4c429b447137 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 18:02:24 -0800 Subject: [PATCH 1124/1168] [SPARK-11978][ML] Move dataset_example.py to examples/ml and rename to dataframe_example.py Since ```Dataset``` has a new meaning in Spark 1.6, we should rename it to avoid confusion. #9873 finished the work of Scala example, here we focus on the Python one. Move dataset_example.py to ```examples/ml``` and rename to ```dataframe_example.py```. BTW, fix minor missing issues of #9873. cc mengxr Author: Yanbo Liang Closes #9957 from yanboliang/SPARK-11978. --- .../dataframe_example.py} | 56 +++++++++++-------- .../spark/examples/ml/DataFrameExample.scala | 8 +-- 2 files changed, 38 insertions(+), 26 deletions(-) rename examples/src/main/python/{mllib/dataset_example.py => ml/dataframe_example.py} (53%) diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/ml/dataframe_example.py similarity index 53% rename from examples/src/main/python/mllib/dataset_example.py rename to examples/src/main/python/ml/dataframe_example.py index e23ecc0c5d302..d2644ca335654 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -16,8 +16,8 @@ # """ -An example of how to use DataFrame as a dataset for ML. Run with:: - bin/spark-submit examples/src/main/python/mllib/dataset_example.py +An example of how to use DataFrame for ML. Run with:: + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -28,36 +28,48 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.mllib.util import MLUtils from pyspark.mllib.stat import Statistics - -def summarize(dataset): - print("schema: %s" % dataset.schema().json()) - labels = dataset.map(lambda r: r.label) - print("label average: %f" % labels.mean()) - features = dataset.map(lambda r: r.features) - summary = Statistics.colStats(features) - print("features average: %r" % summary.mean()) - if __name__ == "__main__": if len(sys.argv) > 2: - print("Usage: dataset_example.py ", file=sys.stderr) + print("Usage: dataframe_example.py ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="DatasetExample") + sc = SparkContext(appName="DataFrameExample") sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" - points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() - summarize(dataset0) + + # Load input data + print("Loading LIBSVM file with UDT from " + input + ".") + df = sqlContext.read.format("libsvm").load(input).cache() + print("Schema from LIBSVM:") + df.printSchema() + print("Loaded training data as a DataFrame with " + + str(df.count()) + " records.") + + # Show statistical summary of labels. + labelSummary = df.describe("label") + labelSummary.show() + + # Convert features column to an RDD of vectors. + features = df.select("features").map(lambda r: r.features) + summary = Statistics.colStats(features) + print("Selected features column with average values:\n" + + str(summary.mean())) + + # Save the records in a parquet file. tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print("Save dataset as a Parquet file to %s." % tempdir) - dataset0.saveAsParquetFile(tempdir) - print("Load it back and summarize it again.") - dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() - summarize(dataset1) + print("Saving to " + tempdir + " as Parquet file.") + df.write.parquet(tempdir) + + # Load the records back. + print("Loading Parquet file with UDT from " + tempdir) + newDF = sqlContext.read.parquet(tempdir) + print("Schema from Parquet:") + newDF.printSchema() shutil.rmtree(tempdir) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 424f00158c2f2..0a477abae5679 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -44,10 +44,10 @@ object DataFrameExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using DataFrame as a Dataset for ML.") + val parser = new OptionParser[Params]("DataFrameExample") { + head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataset") + .text(s"input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -88,7 +88,7 @@ object DataFrameExample { // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString + val outputDir = new File(tmpDir, "dataframe").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) From 1c5cd9ee0a1813ed087df85718f11359b821c795 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 11 Dec 2015 19:07:48 -0800 Subject: [PATCH 1125/1168] [SPARK-12298][SQL] Fix infinite loop in DataFrame.sortWithinPartitions Modifies the String overload to call the Column overload and ensures this is called in a test. Author: Ankur Dave Closes #10271 from ankurdave/SPARK-12298. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index da180a2ba09d9..497bd48266770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -609,7 +609,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions(sortCol, sortCols : _*) + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5353fefaf4b84..c0bbf73ab1188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1090,8 +1090,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } // Distribute into one partition and order by. This partition should contain all the values. - val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc) - // Walk each partition and verify that it is sorted descending and not globally sorted. + val df6 = data.repartition(1, $"a").sortWithinPartitions("b") + // Walk each partition and verify that it is sorted ascending and not globally sorted. df6.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true From 5b39f6c7c2dad6b4ded9d3e6b659a965e0a39fbe Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 11 Dec 2015 20:55:16 -0800 Subject: [PATCH 1126/1168] [SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases The existing sample functions miss the parameter `seed`, however, the corresponding function interface in `generics` has such a parameter. Thus, although the function caller can call the function with the 'seed', we are not using the value. This could cause SparkR unit tests failed. For example, I hit it in another PR: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull Author: gatorsmile Closes #10160 from gatorsmile/sampleR. --- R/pkg/R/DataFrame.R | 17 +++++++++++------ R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 975b058c0aaf1..764597d1e32b4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -662,6 +662,7 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value #' #' @family DataFrame functions #' @rdname sample @@ -677,13 +678,17 @@ setMethod("unique", #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) @@ -692,8 +697,8 @@ setMethod("sample", setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, withReplacement, fraction, seed) }) #' nrow diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ed9b2c9d4d16c..071fd310fd58a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -724,6 +724,10 @@ test_that("sample on a DataFrame", { sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) From a82431f14a7bf9e0d3b88491016498862a15bb68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Sat, 12 Dec 2015 08:51:52 +0000 Subject: [PATCH 1127/1168] [SPARK-11193] Use Java ConcurrentHashMap instead of SynchronizedMap trait in order to avoid ClassCastException due to KryoSerializer in KinesisReceiver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10203 from jbonofre/SPARK-11193. --- .../streaming/kinesis/KinesisReceiver.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 05080835fc4ad..80edda59e1719 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -124,8 +125,7 @@ private[kinesis] class KinesisReceiver[T]( private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] /** Sequence number ranges of data added to each generated block */ - private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] - with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + private val blockIdToSeqNumRanges = new ConcurrentHashMap[StreamBlockId, SequenceNumberRanges] /** * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. @@ -135,8 +135,8 @@ private[kinesis] class KinesisReceiver[T]( /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ - private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] - with mutable.SynchronizedMap[String, String] + private val shardIdToLatestStoredSeqNum = new ConcurrentHashMap[String, String] + /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). @@ -222,7 +222,7 @@ private[kinesis] class KinesisReceiver[T]( /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { - shardIdToLatestStoredSeqNum.get(shardId) + Option(shardIdToLatestStoredSeqNum.get(shardId)) } /** @@ -257,7 +257,7 @@ private[kinesis] class KinesisReceiver[T]( * for next block. Internally, this is synchronized with `rememberAddedRange()`. */ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { - blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) seqNumRangesInCurrentBlock.clear() logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") } @@ -265,7 +265,7 @@ private[kinesis] class KinesisReceiver[T]( /** Store the block along with its associated ranges */ private def storeBlockWithRanges( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { - val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + val rangesToReportOption = Option(blockIdToSeqNumRanges.remove(blockId)) if (rangesToReportOption.isEmpty) { stop("Error while storing block into Spark, could not find sequence number ranges " + s"for block $blockId") @@ -294,7 +294,7 @@ private[kinesis] class KinesisReceiver[T]( // Note that we are doing this sequentially because the array of sequence number ranges // is assumed to be rangesToReport.ranges.foreach { range => - shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + shardIdToLatestStoredSeqNum.put(range.shardId, range.toSeqNumber) } } From 53389bd84b850b54a28ee5b316afd5c6b72ca6e5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sat, 12 Dec 2015 17:47:01 -0800 Subject: [PATCH 1128/1168] [SPARK-12199][DOC] Follow-up: Refine example code in ml-features.md https://issues.apache.org/jira/browse/SPARK-12199 Follow-up PR of SPARK-11551. Fix some errors in ml-features.md mengxr Author: Xusen Yin Closes #10193 from yinxusen/SPARK-12199. --- docs/ml-features.md | 22 +++++++++---------- .../examples/ml/JavaBinarizerExample.java | 2 +- .../python/ml/polynomial_expansion_example.py | 6 ++--- ....scala => ElementwiseProductExample.scala} | 0 4 files changed, 15 insertions(+), 15 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/ml/{ElementWiseProductExample.scala => ElementwiseProductExample.scala} (100%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 8b00cc652dc7a..158f3f201899c 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. @@ -411,7 +411,7 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
    @@ -669,7 +669,7 @@ for more details on the API. The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm.
    -
    +
    Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. @@ -677,7 +677,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    -
    +
    Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. @@ -685,7 +685,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    -
    +
    Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. @@ -709,7 +709,7 @@ Note that if the standard deviation of a feature is zero, it will return default The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
    -
    +
    Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. @@ -717,7 +717,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    -
    +
    Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. @@ -725,7 +725,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    -
    +
    Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. @@ -788,7 +788,7 @@ More details can be found in the API docs for [Bucketizer](api/scala/index.html# The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
    -
    +
    Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. @@ -796,7 +796,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    -
    +
    Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. @@ -804,7 +804,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    -
    +
    Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java index 9698cac504371..1eda1f694fc27 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -59,7 +59,7 @@ public static void main(String[] args) { DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); + Double binarized_value = r.getDouble(0); System.out.println(binarized_value); } // $example off$ diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py index 3d4fafd1a42e9..89f5cbe8f2f41 100644 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -30,9 +30,9 @@ # $example on$ df = sqlContext\ - .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], + .createDataFrame([(Vectors.dense([-2.0, 2.3]),), + (Vectors.dense([0.0, 0.0]),), + (Vectors.dense([0.6, -1.1]),)], ["features"]) px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") polyDF = px.transform(df) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala From 6659b4eb32b484db3d2d1dcd13be14d8305b2727 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 12 Dec 2015 21:58:55 -0800 Subject: [PATCH 1129/1168] [SPARK-12267][CORE] Store the remote RpcEnv address to send the correct disconnetion message Author: Shixiong Zhu Closes #10261 from zsxwing/SPARK-12267. --- .../spark/deploy/master/ApplicationInfo.scala | 1 + .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 21 ++++++++++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 42 +++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ac553b71115df..7e2cf956c7253 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -66,6 +66,7 @@ private[spark] class ApplicationInfo( nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] executorLimit = Integer.MAX_VALUE + appUIUrlAtHistoryServer = None } private def newExecutorId(useID: Option[Int] = None): Int = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1afc1ff59f2f9..f41efb097b4be 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -690,7 +690,7 @@ private[deploy] object Worker extends Logging { val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.masters, args.workDir) + args.memory, args.masters, args.workDir, conf = conf) rpcEnv.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68c5f44145b0d..f82fd4eb5756d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -553,6 +553,9 @@ private[netty] class NettyRpcHandler( // A variable to track whether we should dispatch the RemoteProcessConnected message. private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() + // A variable to track the remote RpcEnv addresses of all clients + private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() + override def receive( client: TransportClient, message: ByteBuffer, @@ -580,6 +583,12 @@ private[netty] class NettyRpcHandler( // Create a new message with the socket address of the client as the sender. RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) } else { + // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for + // the listening address + val remoteEnvAddress = requestMessage.senderAddress + if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } requestMessage } } @@ -591,6 +600,12 @@ private[netty] class NettyRpcHandler( if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessConnectionError for the remote RpcEnv listening address + val remoteEnvAddress = remoteAddresses.get(clientAddr) + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -606,6 +621,12 @@ private[netty] class NettyRpcHandler( val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + val remoteEnvAddress = remoteAddresses.remove(clientAddr) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessDisconnected for the remote RpcEnv listening address + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a61d0479aacdb..6d153eb04e04f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -545,6 +545,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("network events between non-client-mode RpcEnvs") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events-non-client") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + assert(events.contains(("onDisconnected", remoteAddress))) + } + } + test("sendWithReply: unserializable error") { env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { override val rpcEnv = env From af3e0717e3756127912e3f2cbb6ec63de663cf46 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 13 Dec 2015 22:06:39 -0800 Subject: [PATCH 1130/1168] [SPARK-12281][CORE] Fix a race condition when reporting ExecutorState in the shutdown hook 1. Make sure workers and masters exit so that no worker or master will still be running when triggering the shutdown hook. 2. Set ExecutorState to FAILED if it's still RUNNING when executing the shutdown hook. This should fix the potential exceptions when exiting a local cluster ``` java.lang.AssertionError: assertion failed: executor 4 state transfer from RUNNING to RUNNING is illegal at scala.Predef$.assert(Predef.scala:179) at org.apache.spark.deploy.master.Master$$anonfun$receive$1.applyOrElse(Master.scala:260) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) java.lang.IllegalStateException: Shutdown hooks cannot be modified during shutdown. at org.apache.spark.util.SparkShutdownHookManager.add(ShutdownHookManager.scala:246) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:191) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:180) at org.apache.spark.deploy.worker.ExecutorRunner.start(ExecutorRunner.scala:73) at org.apache.spark.deploy.worker.Worker$$anonfun$receive$1.applyOrElse(Worker.scala:474) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` Author: Shixiong Zhu Closes #10269 from zsxwing/executor-state. --- .../scala/org/apache/spark/deploy/LocalSparkCluster.scala | 2 ++ .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 ++--- .../org/apache/spark/deploy/worker/ExecutorRunner.scala | 5 +++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 83ccaadfe7447..5bb62d37d6374 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -75,6 +75,8 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected workerRpcEnvs.foreach(_.shutdown()) masterRpcEnvs.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.awaitTermination()) + masterRpcEnvs.foreach(_.awaitTermination()) masterRpcEnvs.clear() workerRpcEnvs.clear() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 04b20e0d6ab9c..1355e1ad1b523 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,9 +257,8 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - if (oldState != ExecutorState.LAUNCHING) { - logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") - } + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") appInfo.resetRetryCount() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 25a17473e4b53..9a42487bb37aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -71,6 +71,11 @@ private[deploy] class ExecutorRunner( workerThread.start() // Shutdown hook that kills actors on shutdown. shutdownHook = ShutdownHookManager.addShutdownHook { () => + // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will + // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`. + if (state == ExecutorState.RUNNING) { + state = ExecutorState.FAILED + } killProcess(Some("Worker shutting down")) } } From cb0b87242dad295875487b97aaea16813092e200 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 13 Dec 2015 22:57:01 -0800 Subject: [PATCH 1131/1168] [SPARK-12213][SQL] use multiple partitions for single distinct query Currently, we could generate different plans for query with single distinct (depends on spark.sql.specializeSingleDistinctAggPlanning), one works better on low cardinality columns, the other works better for high cardinality column (default one). This PR change to generate a single plan (three aggregations and two exchanges), which work better in both cases, then we could safely remove the flag `spark.sql.specializeSingleDistinctAggPlanning` (introduced in 1.6). For a query like `SELECT COUNT(DISTINCT a) FROM table` will be ``` AGG-4 (count distinct) Shuffle to a single reducer Partial-AGG-3 (count distinct, no grouping) Partial-AGG-2 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on a) ``` This PR also includes large refactor for aggregation (reduce 500+ lines of code) cc yhuai nongli marmbrus Author: Davies Liu Closes #10228 from davies/single_distinct. --- .../spark/sql/catalyst/CatalystConf.scala | 7 - .../DistinctAggregationRewriter.scala | 11 +- .../scala/org/apache/spark/sql/SQLConf.scala | 15 - .../aggregate/AggregationIterator.scala | 417 ++++++----------- .../aggregate/SortBasedAggregate.scala | 29 +- .../SortBasedAggregationIterator.scala | 47 +- .../aggregate/TungstenAggregate.scala | 25 +- .../TungstenAggregationIterator.scala | 439 +++--------------- .../spark/sql/execution/aggregate/utils.scala | 280 +++++------ .../execution/AggregationQuerySuite.scala | 142 +++--- 10 files changed, 422 insertions(+), 990 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 7c2b8a9407884..2c7c58e66b855 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean - - protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } - - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { - throw new UnsupportedOperationException - } } /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9c78f6d4cc71b..4e7d1341028ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { - // When the flag is set to specialize single distinct agg planning, - // we will rely on our Aggregation strategy to handle queries with a single - // distinct column. - distinctAggGroups.size > 1 - } else { - distinctAggGroups.size >= 1 - } - if (shouldRewrite) { + // Aggregation strategy can handle the query with single distinct + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 58adf64e49869..3d819262859f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -449,18 +449,6 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) - val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = - booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(false), - isPublic = false, - doc = "When true, if a query only has a single distinct column and it has " + - "grouping expressions, we will use our planner rule to handle this distinct " + - "column (other cases are handled by DistinctAggregationRewriter). " + - "When false, we will always use DistinctAggregationRewriter to plan " + - "aggregation queries with DISTINCT keyword. This is an internal flag that is " + - "used to benchmark the performance impact of using DistinctAggregationRewriter to " + - "plan aggregation queries with a single distinct column.") - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = - getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 008478a6a0e17..0c74df0aa5fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import scala.collection.mutable.ArrayBuffer - /** - * The base class of [[SortBasedAggregationIterator]]. + * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]]. * It mainly contains two parts: * 1. It initializes aggregate functions. * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of @@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer * is used to generate result. */ abstract class AggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends Iterator[InternalRow] with Logging { + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// // Initializing functions. /////////////////////////////////////////////////////////////////////////// - // An Seq of all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - protected val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - require( - allAggregateExpressions.map(_.mode).distinct.length <= 2, - s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") - /** - * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: - * - Partial-only: all AggregateExpressions have the mode of Partial; - * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); - * - Final-only: all AggregateExpressions have the mode of Final; - * - Final-Complete: some AggregateExpressions have the mode of Final and - * others have the mode of Complete; - * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions - * with mode Complete in completeAggregateExpressions; and - * - Grouping-only: there is no AggregateExpression. - */ - protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ + { + val modes = aggregateExpressions.map(_.mode).distinct.toSet + require(modes.size <= 2, + s"$aggregateExpressions are not supported because they have more than 2 distinct modes.") + require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)), + s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") + } // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction] = { + protected def initializeAggregateFunctions( + expressions: Seq[AggregateExpression], + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction](expressions.length) var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { + while (i < expressions.length) { + val func = expressions(i).aggregateFunction + val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, valueAttributes) + BindReferences.bindReference(func, inputAttributes) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. @@ -117,15 +111,18 @@ abstract class AggregationIterator( functions } + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are imperative aggregate functions. // ImperativeAggregateFunctionPositions will be [1, 2]. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { case agg: DeclarativeAggregate => case _ => positions += i } @@ -134,17 +131,9 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All imperative aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all expression-based aggregates. - private[this] val expressionAggInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.initialValues // For the positions corresponding to imperative aggregate functions, we'll use special // no-op expressions which are ignored during projection code-generation. @@ -154,248 +143,112 @@ abstract class AggregationIterator( } // All imperative AggregateFunctions. - private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = allImperativeAggregateFunctionPositions - .map(allAggregateFunctions) + .map(aggregateFunctions) .map(_.asInstanceOf[ImperativeAggregate]) - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// - // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalExpressionAggMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalExpressionAggMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. + val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) - } else { - safeOutputRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.aggBufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { // Grouping-only: we only output values of grouping expressions. - case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) - } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) @@ -405,10 +258,4 @@ abstract class AggregationIterator( i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ - protected def newBuffer: MutableRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ee982453c3287..c5470a6989de7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -42,10 +40,8 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = false - + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -76,31 +72,24 @@ case class SortBasedAggregate( if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. - Iterator[InternalRow]() + Iterator[UnsafeRow]() } else { - val groupingKeyProjection = - UnsafeProjection.create(groupingExpressions, child.output) - val outputIter = new SortBasedAggregationIterator( - groupingKeyProjection, - groupingExpressions.map(_.toAttribute), + groupingExpressions, child.output, iter, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows, numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. numOutputRows += 1 - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter } @@ -109,7 +98,7 @@ case class SortBasedAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index fe5c3195f867b..ac920aa8bc7f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric /** * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyProjection: InternalRow => InternalRow, - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean, numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) @@ -76,10 +73,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -94,7 +91,7 @@ class SortBasedAggregationIterator( if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) val inputRow = inputIterator.next() - nextGroupingKey = groupingKeyProjection(inputRow).copy() + nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() numInputRows += 1 sortedInputHasNewGroup = true @@ -120,7 +117,7 @@ class SortBasedAggregationIterator( while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. val currentRow = inputIterator.next() - val groupingKey = groupingKeyProjection(currentRow) + val groupingKey = groupingProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. @@ -146,7 +143,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -162,8 +159,8 @@ class SortBasedAggregationIterator( } } - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 920de615e1d86..b8849c827048a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { private[this] val aggregateBufferAttributes = { - (nonCompleteAggregateExpressions ++ completeAggregateExpressions) - .flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), @@ -53,9 +50,7 @@ case class TungstenAggregate( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,10 +89,8 @@ case class TungstenAggregate( val aggregationIterator = new TungstenAggregationIterator( groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, @@ -119,7 +112,7 @@ case class TungstenAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => @@ -135,9 +128,7 @@ case class TungstenAggregate( } object TungstenAggregate { - def supportsAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 04391443920ac..582fdbe547061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. @@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType * * @param groupingExpressions * expressions for grouping keys - * @param nonCompleteAggregateExpressions + * @param aggregateExpressions * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. - * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * @param aggregateAttributes the attributes of the aggregateExpressions' * outputs when they are stored in the final aggregation buffer. - * @param completeAggregateExpressions - * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. - * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs - * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), @@ -97,378 +89,62 @@ class TungstenAggregationIterator( numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) - extends Iterator[UnsafeRow] with Logging { + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// - // A Seq containing all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression] = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - // Check to make sure we do not have more than three modes in our AggregateExpressions. - // If we have, users are hitting a bug and we throw an IllegalStateException. - if (allAggregateExpressions.map(_.mode).distinct.length > 2) { - throw new IllegalStateException( - s"$allAggregateExpressions should have no more than 2 kinds of modes.") - } - // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - // - // The modes of AggregateExpressions. Right now, we can handle the following mode: - // - Partial-only: - // All AggregateExpressions have the mode of Partial. - // For this case, aggregationMode is (Some(Partial), None). - // - PartialMerge-only: - // All AggregateExpressions have the mode of PartialMerge). - // For this case, aggregationMode is (Some(PartialMerge), None). - // - Final-only: - // All AggregateExpressions have the mode of Final. - // For this case, aggregationMode is (Some(Final), None). - // - Final-Complete: - // Some AggregateExpressions have the mode of Final and - // others have the mode of Complete. For this case, - // aggregationMode is (Some(Final), Some(Complete)). - // - Complete-only: - // nonCompleteAggregateExpressions is empty and we have AggregateExpressions - // with mode Complete in completeAggregateExpressions. For this case, - // aggregationMode is (None, Some(Complete)). - // - Grouping-only: - // There is no AggregateExpression. For this case, AggregationMode is (None,None). - // - private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption - } - - // Initialize all AggregateFunctions by binding references, if necessary, - // and setting inputBufferOffset and mutableBufferOffset. - private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) - var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length - // We need to use this mode instead of func.mode in order to handle aggregation mode switching - // when switching to sort-based aggregation: - val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 - val funcWithBoundReferences = mode match { - case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => - // We need to create BoundReferences if the function is not an - // expression-based aggregate function (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, originalInputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - val updatedFunc = func match { - case function: ImperativeAggregate => - function.withNewInputAggBufferOffset(inputBufferOffset) - case function => function - } - inputBufferOffset += func.aggBufferSchema.length - updatedFunc - } - val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { - case function: ImperativeAggregate => - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - function.withNewMutableAggBufferOffset(mutableBufferOffset) - case function => function - } - mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length - functions(i) = funcWithUpdatedAggBufferOffset - i += 1 - } - functions - } - - private[this] var allAggregateFunctions: Array[AggregateFunction] = - initializeAllAggregateFunctions(initialInputBufferOffset) - - // Positions of those imperative aggregate functions in allAggregateFunctions. - // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are imperative aggregate functions. Then - // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be - // updated when falling back to sort-based aggregation because the positions of the aggregate - // functions do not change in that case. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: DeclarativeAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - /////////////////////////////////////////////////////////////////////////// // Part 2: Methods and fields used by setting aggregation buffer values, // processing input rows from inputIter, and generating output // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values for all expression-based aggregates. - // Note that this projection does not need to be updated when switching to sort-based aggregation - // because the schema of empty aggregation buffers does not change in that case. - private[this] val expressionAggInitialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.initialValues - // For the positions corresponding to imperative aggregate functions, we'll use special - // no-op expressions which are ignored during projection code-generation. - case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)() - } - // Creates a new aggregation buffer and initializes buffer values. - // This function should be only called at most three times (when we create the hash map, - // when we switch to sort-based aggregation, and when we create the re-used buffer for - // sort-based aggregation). + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) .apply(new GenericMutableRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values - allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } - // Creates a function used to process a row based on the given inputAttributes. - private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val joinedRow = new JoinedRow() - - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffer values in row to - // currentBuffer. - finalMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") - } - } - // Creates a function used to generate output rows. - private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - unsafeRowJoiner.join(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: a output row is generated based on - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val joinedRow = new JoinedRow() - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() - // These are the attributes of the row produced by `expressionAggEvalProjection` - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) - - val allImperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } - - // Grouping-only: a output row is generated from values of grouping expressions. - case (None, None) => - val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(currentGroupingKey) - } - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() } } - // An UnsafeProjection used to extract grouping keys from the input rows. - private[this] val groupProjection = - UnsafeProjection.create(groupingExpressions, originalInputAttributes) - - // A function used to process a input row. Its first argument is the aggregation buffer - // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, InternalRow) => Unit = - generateProcessRow(originalInputAttributes) - - // A function used to generate output rows based on the grouping keys (first argument) - // and the corresponding aggregation buffer (second argument). - private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = - generateResultProjection() - // An aggregation buffer containing initial buffer values. It is used to // initialize other aggregation buffers. private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() @@ -482,7 +158,7 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity @@ -499,7 +175,7 @@ class TungstenAggregationIterator( if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. - val groupingKey = groupProjection.apply(null) + val groupingKey = groupingProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) while (inputIter.hasNext) { val newInput = inputIter.next() @@ -511,7 +187,7 @@ class TungstenAggregationIterator( while (inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) + val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null if (i < fallbackStartsAt) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -565,25 +241,18 @@ class TungstenAggregationIterator( private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. - val newAggregationMode = aggregationMode match { - case (Some(Partial), None) => (Some(PartialMerge), None) - case (None, Some(Complete)) => (Some(Final), None) - case (Some(Final), Some(Complete)) => (Some(Final), None) + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _) => + agg.copy(mode = Final) case other => other } - aggregationMode = newAggregationMode - - allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) - - // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use inputAggBufferAttributes. - val newInputAttributes: Seq[Attribute] = - allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - - // Set up new processRow and generateOutput. - processRow = generateProcessRow(newInputAttributes) - generateOutput = generateResultProjection() + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) // Step 5: Get the sorted iterator from the externalSorter. sortedKVIterator = externalSorter.sortedIterator() @@ -632,6 +301,9 @@ class TungstenAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + // Processes rows in the current group. It will stop when it find a new group. private def processCurrentSortedGroup(): Unit = { // First, we need to copy nextGroupingKey to currentGroupingKey. @@ -640,7 +312,7 @@ class TungstenAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -655,16 +327,15 @@ class TungstenAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey.equals(groupingKey)) { - processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) hasNext = sortedKVIterator.next() } else { // We find a new group. findNextPartition = true // copyFrom will fail when - nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() - firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() - + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) } } // We have not seen a new group. It means that there is no new row in the input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 76b938cdb694e..83379ae90f703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -42,16 +42,45 @@ object Utils { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = Nil, - nonCompleteAggregateAttributes = Nil, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = resultExpressions, child = child ) :: Nil } + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -59,9 +88,6 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. @@ -73,29 +99,14 @@ object Utils { groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -105,29 +116,14 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val finalAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - } else { - SortBasedAggregate( + val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) - } finalAggregate :: Nil } @@ -140,99 +136,99 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. - val partialAggregateGroupingExpressions = - groupingExpressions ++ namedDistinctColumnExpressions - val partialAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) } // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { - val partialMergeAggregateExpressions = - functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression(aggregateFunction, mode, true) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] } - // 3. Create an Aggregate Operator for the final aggregation. + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, + child = partialMergeAggregate) + } + + // 4. Create an Aggregate Operator for the final aggregation. val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -241,49 +237,27 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val distinctColumnAttributeLookup = - distinctColumnExpressions.zip(distinctColumnAttributes).toMap - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction - .transformDown(distinctColumnAttributeLookup) - .asInstanceOf[AggregateFunction] + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) - - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) }.unzip - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = partialDistinctAggregate) } finalAndCompleteAggregate :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 064c0004b801e..5550198c02fbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -552,80 +551,73 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - Seq(true, false).foreach { specializeSingleDistinctAgg => - val conf = - (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, - specializeSingleDistinctAgg.toString) - withSQLConf(conf) { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), - | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | count(value1), - | count(*), - | count(1), - | count(DISTINCT value1), - | key - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) - } - } + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("single distinct multiple columns set") { From 0acee0c9b3dc4f45232bea674573c831cb4def39 Mon Sep 17 00:00:00 2001 From: yucai Date: Sun, 13 Dec 2015 23:08:21 -0800 Subject: [PATCH 1132/1168] [SPARK-12275][SQL] No plan for BroadcastHint in some condition When SparkStrategies.BasicOperators's "case BroadcastHint(child) => apply(child)" is hit, it only recursively invokes BasicOperators.apply with this "child". It makes many strategies have no change to process this plan, which probably leads to "No plan" issue, so we use planLater to go through all strategies. https://issues.apache.org/jira/browse/SPARK-12275 Author: yucai Closes #10265 from yucai/broadcast_hint. --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 25e98c0bdd431..688555cf136e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil - case BroadcastHint(child) => apply(child) + case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56ad71ea4f487..c70397f9853ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -120,5 +120,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // planner should not crash without a join broadcast(df1).queryExecution.executedPlan + + // SPARK-12275: no physical plan for BroadcastHint in some condition + withTempPath { path => + df1.write.parquet(path.getCanonicalPath) + val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + assert(df1.join(broadcast(pf1)).count() === 4) + } } } From bbdb3e7777e70c8839e689e6e01b29c8cb058ebc Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 14 Dec 2015 13:50:30 +0000 Subject: [PATCH 1133/1168] [MINOR][DOC] Fix broken word2vec link Follow-up of [SPARK-12199](https://issues.apache.org/jira/browse/SPARK-12199) and #10193 where a broken link has been left as is. Author: BenFradet Closes #10282 from BenFradet/SPARK-12199. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 158f3f201899c..677e4bfb916e8 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. From 50a33d8c8bf64db304755c3cd893095fcb3b841b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Dec 2015 09:59:42 -0800 Subject: [PATCH 1134/1168] [SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark JIRA: https://issues.apache.org/jira/browse/SPARK-12016 We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper when loading it in pyspark. Author: Liang-Chi Hsieh Closes #10100 from viirya/fix-load-py-wordvecmodel. --- .../mllib/api/python/PythonMLLibAPI.scala | 33 ---------- .../api/python/Word2VecModelWrapper.scala | 62 +++++++++++++++++++ python/pyspark/mllib/feature.py | 6 +- 3 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8d546e3d6099b..29160a10e16b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable { } } - private[python] class Word2VecModelWrapper(model: Word2VecModel) { - def transform(word: String): Vector = { - model.transform(word) - } - - /** - * Transforms an RDD of words to its vector representation - * @param rdd an RDD of words - * @return an RDD of vector representations of words - */ - def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { - rdd.rdd.map(model.transform) - } - - def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) - } - - def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) - val similarity = Vectors.dense(result.map(_._2)) - val words = result.map(_._1) - List(words, similarity).map(_.asInstanceOf[Object]).asJava - } - - def getVectors: JMap[String, JList[Float]] = { - model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava - } - - def save(sc: SparkContext, path: String): Unit = model.save(sc, path) - } - /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala new file mode 100644 index 0000000000000..0f55980481dcb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Wrapper around Word2VecModel to provide helper methods in Python + */ +private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7b077b058c3fd..7254679ebb533 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -504,7 +504,8 @@ def load(cls, sc, path): """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - return Word2VecModel(jmodel) + model = sc._jvm.Word2VecModelWrapper(jmodel) + return Word2VecModel(model) @ignore_unicode_prefix @@ -546,6 +547,9 @@ class Word2Vec(object): >>> sameModel = Word2VecModel.load(sc, path) >>> model.transform("a") == sameModel.transform("a") True + >>> syms = sameModel.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] >>> from shutil import rmtree >>> try: ... rmtree(path) From 490e0bc12886e0bb3af9e004ba164cfe44f93d3b Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 14 Dec 2015 16:13:55 -0800 Subject: [PATCH 1135/1168] [SPARK-12327] Disable commented code lintr temporarily cc yhuai felixcheung shaneknapp Author: Shivaram Venkataraman Closes #10300 from shivaram/comment-lintr-disable. --- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e6..39c872663ad44 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 4f28c601a847aa5be5b4c6982f9e9172e16cc0a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Dec 2015 16:48:11 -0800 Subject: [PATCH 1136/1168] [SPARK-12274][SQL] WrapOption should not have type constraint for child I think it was a mistake, and we have not catched it so far until https://github.com/apache/spark/pull/10260 which begin to check if the `fromRowExpression` is resolved. Author: Wenchen Fan Closes #10263 from cloud-fan/encoder. --- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b2facfda24446..96bc4fe67a985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -296,15 +296,12 @@ case class UnwrapOption( * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. */ -case class WrapOption(child: Expression) - extends UnaryExpression with ExpectsInputTypes { +case class WrapOption(child: Expression) extends UnaryExpression { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") From 8a8d663df3f5a6c5495d4b510617ee37a800decb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 18:33:45 -0800 Subject: [PATCH 1137/1168] [SPARK-12188][SQL][FOLLOW-UP] Code refactoring and comment correction in Dataset APIs marmbrus This PR is to address your comment. Thanks for your review! Author: gatorsmile Closes #10214 from gatorsmile/followup12188. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3bd18a14f9e8f..dc69822e92908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -79,7 +79,7 @@ class Dataset[T] private[sql]( /** * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of the given schema. + * bound to the ordinals of this [[Dataset]]'s output schema. */ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) From c8a084a4b3d31ab3a19783a483be97412b46b145 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 19:42:16 -0800 Subject: [PATCH 1138/1168] [SPARK-12288] [SQL] Support UnsafeRow in Coalesce/Except/Intersect. Support UnsafeRow for the Coalesce/Except/Intersect. Could you review if my code changes are ok? davies Thank you! Author: gatorsmile Closes #10285 from gatorsmile/unsafeSupportCIE. --- .../spark/sql/execution/basicOperators.scala | 12 ++++++- .../execution/RowFormatConvertersSuite.scala | 35 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a42aea0b96d43..b3e4688557ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -137,7 +137,7 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = @@ -250,7 +250,9 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { child.execute().coalesce(numPartitions, shuffle = false) } + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -263,6 +265,10 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -275,6 +281,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 13d68a103a225..2328899bb2f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -58,6 +58,41 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { assert(!preparedPlan.outputsUnsafeRows) } + test("coalesce can process unsafe rows") { + val plan = Coalesce(1, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except can process unsafe rows") { + val plan = Except(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except requires all of its input rows' formats to agree") { + val plan = Except(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect can process unsafe rows") { + val plan = Intersect(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect requires all of its input rows' formats to agree") { + val plan = Intersect(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + test("execute() fails an assertion if inputs rows are of different formats") { val e = intercept[AssertionError] { Union(Seq(outputsSafe, outputsUnsafe)).execute() From d9608beac4ad4b8c47e48d49c82a81b37ce6ee46 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 15 Dec 2015 11:38:57 +0000 Subject: [PATCH 1139/1168] [SPARK-12332][TRIVIAL][TEST] Fix minor typo in ResetSystemProperties Fix a minor typo (unbalanced bracket) in ResetSystemProperties. Author: Holden Karau Closes #10303 from holdenk/SPARK-12332-trivial-typo-in-ResetSystemProperties-comment. --- .../scala/org/apache/spark/util/ResetSystemProperties.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index c58db5e606f7c..60fb7abb66d32 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -45,7 +45,7 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. From cd4ce8b0114b540a55451036a9710ea7f560be83 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 09:41:40 -0800 Subject: [PATCH 1140/1168] [STREAMING][MINOR] Fix typo in function name of StateImpl cc\ tdas zsxwing , please review. Thanks a lot. Author: jerryshao Closes #10305 from jerryshao/fix-typo-state-impl. --- streaming/src/main/scala/org/apache/spark/streaming/State.scala | 2 +- .../scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 2 +- .../scala/org/apache/spark/streaming/MapWithStateSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index b47bdda2c2137..42424d67d8838 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -206,7 +206,7 @@ private[streaming] class StateImpl[S] extends State[S] { * Update the internal data and flags in `this` to the given state that is going to be timed out. * This method allows `this` object to be reused across many state records. */ - def wrapTiminoutState(newState: S): Unit = { + def wrapTimingOutState(newState: S): Unit = { this.state = newState defined = true timingOut = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index ed95171f73ee1..fdf61674a37f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -67,7 +67,7 @@ private[streaming] object MapWithStateRDDRecord { // data returned if (removeTimedoutData && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - wrappedState.wrapTiminoutState(state) + wrappedState.wrapTimingOutState(state) val returned = mappingFunction(batchTime, key, None, wrappedState) mappedData ++= returned newStateMap.remove(key) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 4b08085e09b1f..6b21433f1781b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -125,7 +125,7 @@ class MapWithStateSuite extends SparkFunSuite state.remove() testState(None, shouldBeRemoved = true) - state.wrapTiminoutState(3) + state.wrapTimingOutState(3) testState(Some(3), shouldBeTimingOut = true) } From e05f4551b91d179fcdfb8c171f077fc05debd81c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 15 Dec 2015 16:29:39 -0800 Subject: [PATCH 1141/1168] [MINOR][ML] Rename weights to coefficients for examples/DeveloperApiExample Rename ```weights``` to ```coefficients``` for examples/DeveloperApiExample. cc mengxr jkbradley Author: Yanbo Liang Closes #10280 from yanboliang/spark-coefficients. --- .../examples/ml/JavaDeveloperApiExample.java | 22 +++++++++---------- .../examples/ml/DeveloperApiExample.scala | 16 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 0b4c0d9ba9f8b..b9dd3ad957714 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception { } if (sumPredictions != 0.0) { throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all weights are 0!"); + " even though all coefficients are 0!"); } jsc.stop(); @@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); } @Override @@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) { class MyJavaLogisticRegressionModel extends ClassificationModel { - private Vector weights_; - public Vector weights() { return weights_; } + private Vector coefficients_; + public Vector coefficients() { return coefficients_; } - public MyJavaLogisticRegressionModel(String uid, Vector weights) { + public MyJavaLogisticRegressionModel(String uid, Vector coefficients) { this.uid_ = uid; - this.weights_ = weights; + this.coefficients_ = coefficients; } private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); @@ -208,7 +208,7 @@ public String uid() { * modifier. */ public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, weights_); + double margin = BLAS.dot(features, coefficients_); // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). return Vectors.dense(-margin, margin); @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Number of features the model was trained on. */ - public int numFeatures() { return weights_.size(); } + public int numFeatures() { return coefficients_.size(); } /** * Create a copy of the model. @@ -235,7 +235,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3758edc56198a..c1f63c6a1dce3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -75,7 +75,7 @@ object DeveloperApiExample { prediction }.sum assert(sumPredictions == 0.0, - "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") sc.stop() } @@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String) // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. val numFeatures = oldDataset.take(1)(0).features.size - val weights = Vectors.zeros(numFeatures) // Learning would happen here. + val coefficients = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(uid, weights).setParent(this) + new MyLogisticRegressionModel(uid, coefficients).setParent(this) } override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) @@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String) */ private class MyLogisticRegressionModel( override val uid: String, - val weights: Vector) + val coefficients: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -163,7 +163,7 @@ private class MyLogisticRegressionModel( * confidence for that label. */ override protected def predictRaw(features: Vector): Vector = { - val margin = BLAS.dot(features, weights) + val margin = BLAS.dot(features, coefficients) // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). Vectors.dense(-margin, margin) @@ -173,7 +173,7 @@ private class MyLogisticRegressionModel( override val numClasses: Int = 2 /** Number of features the model was trained on. */ - override val numFeatures: Int = weights.size + override val numFeatures: Int = coefficients.size /** * Create a copy of the model. @@ -182,7 +182,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) + copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent) } } // scalastyle:on println From 960238944cff91baf9c65c14e15c654efc2eea93 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 15 Dec 2015 16:55:58 -0800 Subject: [PATCH 1142/1168] [SPARK-12271][SQL] Improve error message when Dataset.as[ ] has incompatible schemas. Author: Nong Li Closes #10260 from nongli/spark-11271. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/encoders/ExpressionEncoder.scala | 1 + .../spark/sql/catalyst/expressions/objects.scala | 12 +++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 +++++++++- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9013fd050b5f9..ecff8605706de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -184,7 +184,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath - WrapOption(constructorFor(optType, path, newTypePath)) + WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 3e8420ecb9ccf..363178b0e21a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -251,6 +251,7 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + SimpleAnalyzer.checkAnalysis(analyzedPlan) val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 96bc4fe67a985..10ec75eca37f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -23,11 +23,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} -import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -295,13 +293,17 @@ case class UnwrapOption( * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. + * @param optType The type of this option. */ -case class WrapOption(child: Expression) extends UnaryExpression { +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true + override def inputTypes: Seq[AbstractDataType] = optType :: Nil + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 542e4d6c43b9f..8f8db318261db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -481,10 +481,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) } -} + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2].collect() + } + assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + } +} case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) /** From e809ee02ca2bfad50956c4b992031aad43e70d5c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 17:02:14 -0800 Subject: [PATCH 1143/1168] [SPARK-12236][SQL] JDBC filter tests all pass if filters are not really pushed down https://issues.apache.org/jira/browse/SPARK-12236 Currently JDBC filters are not tested properly. All the tests pass even if the filters are not pushed down due to Spark-side filtering. In this PR, Firstly, I corrected the tests to properly check the pushed down filters by removing Spark-side filtering. Also, `!=` was being tested which is actually not pushed down. So I removed them. Lastly, I moved the `stripSparkFilter()` function to `SQLTestUtils` as this functions would be shared for all tests for pushed down filters. This function would be also shared with ORC datasource as the filters for that are also not being tested properly. Author: hyukjinkwon Closes #10221 from HyukjinKwon/SPARK-12236. --- .../datasources/parquet/ParquetFilterSuite.scala | 15 --------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 10 ++++------ .../org/apache/spark/sql/test/SQLTestUtils.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index daf41bc292cc9..6178e37d2a585 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -110,21 +110,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - /** - * Strip Spark-side filtering in order to check if a datasource filters rows correctly. - */ - protected def stripSparkFilter(df: DataFrame): DataFrame = { - val schema = df.schema - val childRDD = df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - .map(row => Row.fromSeq(row.toSeq(schema))) - - sqlContext.createDataFrame(childRDD, schema) - } - test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8c24aa3151bc1..a360947152996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,12 +176,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9214569f18e93..e87da1527c4d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -179,6 +179,21 @@ private[sql] trait SQLTestUtils try f finally sqlContext.sql(s"USE default") } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + /** * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier * way to construct [[DataFrame]] directly out of local data without relying on implicits. From 01dc131cd7fcce980b7eb7117a9c2e22559490d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Tue, 15 Dec 2015 18:06:30 -0800 Subject: [PATCH 1144/1168] [SPARK-12105] [SQL] add convenient show functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10130 from jbonofre/SPARK-12105. --- .../org/apache/spark/sql/DataFrame.scala | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd48266770..b69d4411425d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,17 +160,24 @@ class DataFrame private[sql]( } } + /** + * Compose the string representing rows for output + */ + def showString(): String = { + showString(20) + } + /** * Compose the string representing rows for output - * @param _numRows Number of rows to show + * @param numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { - val numRows = _numRows.max(0) + def showString(numRows: Int, truncate: Boolean = true): String = { + val _numRows = numRows.max(0) val sb = new StringBuilder - val takeResult = take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val takeResult = take(_numRows + 1) + val hasMoreData = takeResult.length > _numRows + val data = takeResult.take(_numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -224,10 +231,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "numRows" records + // For Data that has more than "_numRows" records if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") + val rowsString = if (_numRows == 1) "row" else "rows" + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 8f55388baa71daa2ca3068092cfc8845be23fb45 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Dec 2015 18:11:53 -0800 Subject: [PATCH 1145/1168] [HOTFIX] Compile error from commit 31b3910 --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d4..33b03be1138be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top $_numRows $rowsString\n") + sb.append(s"only showing top ${_numRows} $rowsString\n") } sb.toString() From e45f4c2106ba749626d3b3b1d0df5da8ee1b2655 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 15 Dec 2015 18:15:10 -0800 Subject: [PATCH 1146/1168] [SPARK-12056][CORE] Part 2 Create a TaskAttemptContext only after calling setConf This is continuation of SPARK-12056 where change is applied to SqlNewHadoopRDD.scala andrewor14 FYI Author: tedyu Closes #10164 from tedyu/master. --- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 56cb63d9eff2a..eea780cbaa7e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -148,14 +148,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** From 1ba6d148a10977cd634382780232d20d35ba7070 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 15 Dec 2015 18:16:22 -0800 Subject: [PATCH 1147/1168] [SPARK-12130] Replace shuffleManagerClass with shortShuffleMgrNames in ExternalShuffleBlockResolver Replace shuffleManagerClassName with shortShuffleMgrName is to reduce time of string's comparison. and put sort's comparison on the front. cc JoshRosen andrewor14 Author: Lianhui Wang Closes #10131 from lianhuiwang/spark-12130. --- .../scala/org/apache/spark/shuffle/ShuffleManager.scala | 4 ++++ .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 ++ .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2 ++ .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../network/shuffle/ExternalShuffleBlockResolver.java | 7 +++---- .../apache/spark/network/sasl/SaslIntegrationSuite.java | 3 +-- .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 6 +++--- .../network/shuffle/ExternalShuffleIntegrationSuite.java | 4 ++-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 978366d1a1d1b..a3444bf4daa3b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -28,6 +28,10 @@ import org.apache.spark.{TaskContext, ShuffleDependency} * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + /** Return short name for the ShuffleManager */ + val shortName: String + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index d2e2fc4c110a7..4f30da0878ee1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -34,6 +34,8 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) + override val shortName: String = "hash" + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 66b6bbc61fe8e..9b1a279528428 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -79,6 +79,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shortName: String = "sort" + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed05143877e20..540e1ec003a2b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -200,7 +200,7 @@ private[spark] class BlockManager( val shuffleConfig = new ExecutorShuffleInfo( diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, - shuffleManager.getClass.getName) + shuffleManager.shortName) val MAX_ATTEMPTS = 3 val SLEEP_TIME_SECS = 5 diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e5cb68c8a4dbb..fe933ed650caf 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -183,11 +183,10 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) - || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else if ("hash".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index f573d962fe361..0ea631ea14d70 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -221,8 +221,7 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { // Register an executor so that the next steps work. ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( - new String[] { System.getProperty("java.io.tmpdir") }, 1, - "org.apache.spark.shuffle.sort.SortShuffleManager"); + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index a9958232a1d28..60a1b8b0451fe 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -83,7 +83,7 @@ public void testBadRequests() throws IOException { // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); try { resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); fail("Should have failed"); @@ -96,7 +96,7 @@ public void testBadRequests() throws IOException { public void testSortShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); @@ -115,7 +115,7 @@ public void testSortShuffleBlocks() throws IOException { public void testHashShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + dataContext.createExecutorInfo("hash")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 2095f41d79c16..5e706bf401693 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -49,8 +49,8 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; - static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; From 8195550c89fb4eacad5c88bc69ac9108a325307a Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 15 Dec 2015 18:20:00 -0800 Subject: [PATCH 1148/1168] [SPARK-12351][MESOS] Add documentation about submitting Spark with mesos cluster mode. Adding more documentation about submitting jobs with mesos cluster mode. Author: Timothy Chen Closes #10086 from tnachen/mesos_supervise_docs. --- docs/running-on-mesos.md | 26 +++++++++++++++++++++----- docs/submitting-applications.md | 15 ++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index a197d0e373027..3193e17853483 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -150,14 +150,30 @@ it does not need to be redundantly passed in as a system property. Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client can find the results of the driver from the Mesos Web UI. -To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, -passing in the Mesos master url (e.g: mesos://host:5050). +To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url -to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL +to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. -Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves. +For example: +{% highlight bash %} +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 +{% endhighlight %} + + +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local jars. # Mesos Run Modes diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index ac2a14eb56fea..acbb0f298fe47 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -115,6 +115,18 @@ export HADOOP_CONF_DIR=XXX --master spark://207.184.161.138:7077 \ examples/src/main/python/pi.py \ 1000 + +# Run on a Mesos cluster in cluster deploy mode with supervise +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 + {% endhighlight %} # Master URLs @@ -132,9 +144,10 @@ The master URL passed to Spark can be in one of the following formats:
    ") + } else { + if (!forceAdd) { + stackTrace.remove() + } + } +} + +function expandAllThreadStackTrace(toggleButton) { + $('.accordion-heading').each(function() { + //get thread ID + if (!$(this).hasClass("hidden")) { + var trId = $(this).attr('id').match(/thread_([0-9]+)_tr/m)[1] + toggleThreadStackTrace(trId, true) + } + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden') + } +} + +function collapseAllThreadStackTrace(toggleButton) { + $('.accordion-body').each(function() { + $(this).remove() + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden'); + } +} + + +// inOrOut - true: over, false: out +function onMouseOverAndOut(threadId) { + $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); +} + +function onSearchStringChange() { + var searchString = $('#search').val().toLowerCase(); + //remove the stacktrace + collapseAllThreadStackTrace(false) + if (searchString.length == 0) { + $('tr').each(function() { + $(this).removeClass('hidden') + }) + } else { + $('tr').each(function(){ + if($(this).attr('id') && $(this).attr('id').match(/thread_[0-9]+_tr/) ) { + var children = $(this).children() + var found = false + for (i = 0; i < children.length; i++) { + if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) { + found = true + } + } + if (found) { + $(this).removeClass('hidden') + } else { + $(this).addClass('hidden') + } + } + }); + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index c628a0c706553..b54e33a96fa23 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -221,10 +221,8 @@ a.expandbutton { cursor: pointer; } -.executor-thread { - background: #E6E6E6; -} - -.non-executor-thread { - background: #FAFAFA; +.threaddump-td-mouseover { + background-color: #49535a !important; + color: white; + cursor:pointer; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index b0a2cb4aa4d4b..58575d154ce5c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,44 +60,49 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } } }.map { thread => - val threadName = thread.threadName - val className = "accordion-heading " + { - if (threadName.contains("Executor task launch")) { - "executor-thread" - } else { - "non-executor-thread" - } - } -
    - -
    + + + + + + } + +
    +

    Updated at {UIUtils.formatDate(time)}

    + { + // scalastyle:off +

    + Expand All +

    +

    +
    +
    +
    +
    + Search:
    +
    +

    + // scalastyle:on } - -
    -

    Updated at {UIUtils.formatDate(time)}

    - { - // scalastyle:off -

    - Expand All -

    -

    - // scalastyle:on - } -
    {dumpRows}
    -
    +
    MLlib modelPMML model
    `spark.mllib` modelPMML model
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.rpc.numRetries 3Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager instead.
    ExecutorDriver(random)Class file serverspark.replClassServer.portJetty-based. Only used in Spark shells.
    Executor / Driver Executor / Driver
    spark.memory.offHeap.enabledtrue + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. +
    spark.memory.offHeap.size0 + The absolute amount of memory which can be used for off-heap allocation. + This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This must be set to a positive value when spark.memory.offHeap.enabled=true. +
    spark.memory.useLegacyMode false
    mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... + To submit with --deploy-mode cluster, the HOST:PORT should be configured to connect to the MesosClusterDispatcher.
    yarn Connect to a YARN cluster in - client or cluster mode depending on the value of --deploy-mode. + client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
    yarn-client Equivalent to yarn with --deploy-mode client, From e7046b507a1edcaf5c92d26133477791dd23d89d Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 15 Dec 2015 18:21:00 -0800 Subject: [PATCH 1149/1168] [SPARK-9516][UI] Improvement of Thread Dump Page https://issues.apache.org/jira/browse/SPARK-9516 - [x] new look of Thread Dump Page - [x] click column title to sort - [x] grep - [x] search as you type squito JoshRosen It's ready for the review now Author: CodingCat Closes #7910 from CodingCat/SPARK-9516. --- .../org/apache/spark/ui/static/sorttable.js | 6 +- .../org/apache/spark/ui/static/table.js | 72 ++++++++++++++++++ .../org/apache/spark/ui/static/webui.css | 10 +-- .../ui/exec/ExecutorThreadDumpPage.scala | 73 ++++++++++--------- 4 files changed, 118 insertions(+), 43 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index a73d9a5cbc215..ff241470f32df 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -169,7 +169,7 @@ sorttable = { for (var i=0; i
    " +
    +            stackTraceText +  "
    {threadId}{thread.threadName}{thread.threadState}
    + + + + + + {dumpRows} +
    Thread IDThread NameThread State
    +
    }.getOrElse(Text("Error fetching thread dump")) UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) } From 9c36948d30b15e1321c7b0b9fa9e87799fed6c78 Mon Sep 17 00:00:00 2001 From: "Richard W. Eggert II" Date: Tue, 15 Dec 2015 18:22:58 -0800 Subject: [PATCH 1150/1168] [SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation These changes rework the implementations of `SimpleFutureAction`, `ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` such that asynchronous callbacks on the generated `Futures` NEVER block waiting for a job to complete. A small amount of mutex synchronization is necessary to protect the internal fields that manage cancellation, but these locks are only held very briefly and in practice should almost never cause any blocking to occur. The existing blocking APIs of these classes are retained, but they simply delegate to the underlying non-blocking API and `Await` the results with indefinite timeouts. Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026 Also fixes: https://issues.apache.org/jira/browse/SPARK-4514 This pull request contains all my own original work, which I release to the Spark project under its open source license. Author: Richard W. Eggert II Closes #9264 from reggert/fix-futureaction. --- .../scala/org/apache/spark/FutureAction.scala | 164 +++++++----------- .../apache/spark/rdd/AsyncRDDActions.scala | 48 ++--- .../apache/spark/scheduler/DAGScheduler.scala | 8 +- .../apache/spark/scheduler/JobWaiter.scala | 48 ++--- .../test/scala/org/apache/spark/Smuggle.scala | 82 +++++++++ .../org/apache/spark/StatusTrackerSuite.scala | 26 +++ .../spark/rdd/AsyncRDDActionsSuite.scala | 33 +++- 7 files changed, 251 insertions(+), 158 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/Smuggle.scala diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 48792a958130c..2a8220ff40090 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -20,13 +20,15 @@ package org.apache.spark import java.util.Collections import java.util.concurrent.TimeUnit +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobWaiter -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.{Failure, Try} /** * A future for the result of an action to support cancellation. This is an extension of the @@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] { * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ +@DeveloperApi class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { - if (!atMost.isFinite()) { - awaitResult() - } else jobWaiter.synchronized { - val finishTime = System.currentTimeMillis() + atMost.toMillis - while (!isCompleted) { - val time = System.currentTimeMillis() - if (time >= finishTime) { - throw new TimeoutException - } else { - jobWaiter.wait(finishTime - time) - } - } - } + jobWaiter.completionFuture.ready(atMost) this } @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { - ready(atMost)(permit) - awaitResult() match { - case scala.util.Success(res) => res - case scala.util.Failure(e) => throw e - } + jobWaiter.completionFuture.ready(atMost) + assert(value.isDefined, "Future has not completed properly") + value.get.get } override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { - executor.execute(new Runnable { - override def run() { - func(awaitResult()) - } - }) + jobWaiter.completionFuture onComplete {_ => func(value.get)} } override def isCompleted: Boolean = jobWaiter.jobFinished override def isCancelled: Boolean = _cancelled - override def value: Option[Try[T]] = { - if (jobWaiter.jobFinished) { - Some(awaitResult()) - } else { - None - } - } - - private def awaitResult(): Try[T] = { - jobWaiter.awaitResult() match { - case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception) => scala.util.Failure(e) - } - } + override def value: Option[Try[T]] = + jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } +/** + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ +@DeveloperApi +trait JobSubmitter { + /** + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): FutureAction[R] +} + + /** * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the - * action thread if it is being blocked by a job. + * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending + * jobs. */ -class ComplexFutureAction[T] extends FutureAction[T] { +@DeveloperApi +class ComplexFutureAction[T](run : JobSubmitter => Future[T]) + extends FutureAction[T] { self => - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ + @volatile private var _cancelled = false - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - @volatile private var jobs: Seq[Int] = Nil + @volatile private var subActions: List[FutureAction[_]] = Nil // A promise used to signal the future. - private val p = promise[T]() + private val p = Promise[T]().tryCompleteWith(run(jobSubmitter)) - override def cancel(): Unit = this.synchronized { + override def cancel(): Unit = synchronized { _cancelled = true - if (thread != null) { - thread.interrupt() - } - } - - /** - * Executes some action enclosed in the closure. To properly enable cancellation, the closure - * should use runJob implementation in this promise. See takeAsync for example. - */ - def run(func: => T)(implicit executor: ExecutionContext): this.type = { - scala.concurrent.future { - thread = Thread.currentThread - try { - p.success(func) - } catch { - case e: Exception => p.failure(e) - } finally { - // This lock guarantees when calling `thread.interrupt()` in `cancel`, - // thread won't be set to null. - ComplexFutureAction.this.synchronized { - thread = null - } - } - } - this + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel()) } - /** - * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ - def runJob[T, U, R]( + private def jobSubmitter = new JobSubmitter { + def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { - // If the action hasn't been cancelled yet, submit the job. The check and the submitJob - // command need to be in an atomic block. - val job = this.synchronized { + resultFunc: => R): FutureAction[R] = self.synchronized { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. if (!isCancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + val job = rdd.context.submitJob( + rdd, + processPartition, + partitions, + resultHandler, + resultFunc) + subActions = job :: subActions + job } else { throw new SparkException("Action has been cancelled") } } - - this.jobs = jobs ++ job.jobIds - - // Wait for the job to complete. If the action is cancelled (with an interrupt), - // cancel the job and stop the execution. This is not in a synchronized block because - // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. - try { - Await.ready(job, Duration.Inf) - } catch { - case e: InterruptedException => - job.cancel() - throw new SparkException("Action has been cancelled") - } } override def isCancelled: Boolean = _cancelled @@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds: Seq[Int] = jobs + def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) } + private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { @@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S Await.ready(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) - case Failure(exception) => + case scala.util.Failure(exception) => if (isCancelled) { throw new CancellationException("Job cancelled").initCause(exception) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d5e853613b05b..14f541f937b4c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,12 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] val callSite = self.context.getCallSite - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - self.context.setCallSite(callSite) - while (results.size < num && partsScanned < totalParts) { + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 @@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap {_ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + numPartsToTry) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5582720bbcff2..8d0e0c8624a55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -610,11 +611,12 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + Await.ready(waiter.completionFuture, atMost = Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a0..4326135186a73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -28,17 +32,15 @@ private[spark] class JobWaiter[T]( resultHandler: (Int, T) => Unit) extends JobListener { - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished - + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. - private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +51,17 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. + synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } + override def jobFailed(exception: Exception): Unit = + jobPromise.failure(exception) - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() - } - return jobResult - } } diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 0000000000000..01694a6e6f741 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d25298..5483f2b8434aa 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad66..de015ebd5d237 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. + Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } } From 442df7c8775c519b70fb3b8d250355722d100998 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 18:24:23 -0800 Subject: [PATCH 1151/1168] [SPARK-10123][DEPLOY] Support specifying deploy mode from configuration Please help to review, thanks a lot. Author: jerryshao Closes #10195 from jerryshao/SPARK-10123. --- .../spark/deploy/SparkSubmitArguments.scala | 5 ++- .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++++++ docs/configuration.md | 15 +++++-- .../apache/spark/launcher/SparkLauncher.java | 3 ++ .../launcher/SparkSubmitCommandBuilder.java | 7 ++-- 5 files changed, 64 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 18a1c52ae53fb..915ef81b4eae3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -176,7 +176,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull - deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull + deployMode = Option(deployMode) + .orElse(sparkProperties.get("spark.submit.deployMode")) + .orElse(env.get("DEPLOY_MODE")) + .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d494b0caab85f..2626f5a16dfb8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -136,6 +136,47 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("specify deploy mode through configuration") { + val clArgs = Seq( + "--master", "yarn", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + + appArgs.deployMode should be ("client") + sysProps("spark.submit.deployMode") should be ("client") + + // Both cmd line and configuration are specified, cmdline option takes the priority + val clArgs1 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--conf", "spark.submit.deployMode=client", + "-class", "org.SomeClass", + "thejar.jar" + ) + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + + appArgs1.deployMode should be ("cluster") + sysProps1("spark.submit.deployMode") should be ("cluster") + + // Neither cmdline nor configuration are specified, client mode is the default choice + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + appArgs2.deployMode should be (null) + + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + appArgs2.deployMode should be ("client") + sysProps2("spark.submit.deployMode") should be ("client") + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/docs/configuration.md b/docs/configuration.md index 55cf4b2dac5f5..38d3d059f9d31 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,7 +48,7 @@ The following format is accepted: 1y (years) -Properties that specify a byte size should be configured with a unit of size. +Properties that specify a byte size should be configured with a unit of size. The following format is accepted: 1b (bytes) @@ -192,6 +192,15 @@ of the most common options to set are: allowed master URL's. + + spark.submit.deployMode + (none) + + The deploy mode of Spark driver program, either "client" or "cluster", + Which means to launch driver program locally ("client") + or remotely ("cluster") on one of the nodes inside the cluster. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -1095,7 +1104,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. @@ -1559,7 +1568,7 @@ Apart from these, the following properties are also available, and may be useful spark.streaming.stopGracefullyOnShutdown false - If true, Spark shuts down the StreamingContext gracefully on JVM + If true, Spark shuts down the StreamingContext gracefully on JVM shutdown rather than immediately. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index dd1c93af6ca4c..20e6003a00c19 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -40,6 +40,9 @@ public class SparkLauncher { /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; + /** The Spark deploy mode. */ + public static final String DEPLOY_MODE = "spark.submit.deployMode"; + /** Configuration key for the driver memory. */ public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 312df0b269f32..a95f0f17517d1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -294,10 +294,11 @@ private void constructEnvVarArgs( private boolean isClientMode(Map userProps) { String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); - // Default master is "local[*]", so assume client mode in that case. + String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); + // Default master is "local[*]", so assume client mode in that case return userMaster == null || - "client".equals(deployMode) || - (!userMaster.equals("yarn-cluster") && deployMode == null); + "client".equals(userDeployMode) || + (!userMaster.equals("yarn-cluster") && userDeployMode == null); } /** From f4aee47692612b3924c545c533f5683fc8754f06 Mon Sep 17 00:00:00 2001 From: Naveen Date: Tue, 15 Dec 2015 18:25:22 -0800 Subject: [PATCH 1152/1168] [SPARK-9886][CORE] Fix to use ShutdownHookManager in ExternalBlockStore.scala Author: Naveen Closes #10313 from naveenminchu/branch-fix-SPARK-9886. --- .../spark/storage/ExternalBlockStore.scala | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index db965d54bafd6..94883a54a74e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -177,15 +177,6 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - }) - } - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. private def createBlkManager(): Option[ExternalBlockManager] = { val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) @@ -196,7 +187,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) - addShutdownHook(); + ShutdownHookManager.addShutdownHook { () => + logDebug("Shutdown hook called") + externalBlockManager.map(_.shutdown()) + } Some(instance) } catch { case NonFatal(t) => From 64cf457dd6f217bbf1d4d1ef108716c4fd78510e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 15 Dec 2015 18:28:16 -0800 Subject: [PATCH 1153/1168] [SPARK-12062][CORE] Change Master to asyc rebuild UI when application completes This change builds the event history of completed apps asynchronously so the RPC thread will not be blocked and allow new workers to register/remove if the event log history is very large and takes a long time to rebuild. Author: Bryan Cutler Closes #10284 from BryanCutler/async-MasterUI-SPARK-12062. --- .../apache/spark/deploy/master/Master.scala | 79 ++++++++++++------- .../spark/deploy/master/MasterMessages.scala | 2 + 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b523..fc42bf06e40a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,9 +21,11 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps import scala.util.Random @@ -56,6 +58,10 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val rebuildUIThread = + ThreadUtils.newDaemonSingleThreadExecutor("master-rebuild-ui-thread") + private val rebuildUIContext = ExecutionContext.fromExecutor(rebuildUIThread) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -78,7 +84,8 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 - private val appIdToUI = new HashMap[String, SparkUI] + // Using ConcurrentHashMap so that master-rebuild-ui-thread can add a UI after asyncRebuildUI + private val appIdToUI = new ConcurrentHashMap[String, SparkUI] private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -191,6 +198,7 @@ private[deploy] class Master( checkForWorkerTimeOutTask.cancel(true) } forwardMessageThread.shutdownNow() + rebuildUIThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -367,6 +375,10 @@ private[deploy] class Master( case CheckForWorkerTimeOut => { timeOutDeadWorkers() } + + case AttachCompletedRebuildUI(appId) => + // An asyncRebuildSparkUI has completed, so need to attach to master webUi + Option(appIdToUI.get(appId)).foreach { ui => webUi.attachSparkUI(ui) } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -809,7 +821,7 @@ private[deploy] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { - appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) } + Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) }) completedApps.trimStart(toRemove) @@ -818,7 +830,7 @@ private[deploy] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - rebuildSparkUI(app) + asyncRebuildSparkUI(app) for (exec <- app.executors.values) { killExecutor(exec) @@ -923,49 +935,57 @@ private[deploy] class Master( * Return the UI if successful, else None */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { + val futureUI = asyncRebuildSparkUI(app) + Await.result(futureUI, Duration.Inf) + } + + /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ + private[master] def asyncRebuildSparkUI(app: ApplicationInfo): Future[Option[SparkUI]] = { val appName = app.desc.name val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" - try { - val eventLogDir = app.desc.eventLogDir - .getOrElse { - // Event logging is not enabled for this application - app.appUIUrlAtHistoryServer = Some(notFoundBasePath) - return None - } - + val eventLogDir = app.desc.eventLogDir + .getOrElse { + // Event logging is disabled for this application + app.appUIUrlAtHistoryServer = Some(notFoundBasePath) + return Future.successful(None) + } + val futureUI = Future { val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + - EventLoggingListener.IN_PROGRESS)) + EventLoggingListener.IN_PROGRESS)) - if (inProgressExists) { + val eventLogFile = if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") - } - - val (eventLogFile, status) = if (inProgressExists) { - (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + eventLogFilePrefix + EventLoggingListener.IN_PROGRESS } else { - (eventLogFilePrefix, " (completed)") + eventLogFilePrefix } val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) - val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { - replayBus.replay(logInput, eventLogFile, maybeTruncated) + replayBus.replay(logInput, eventLogFile, inProgressExists) } finally { logInput.close() } - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) + + Some(ui) + }(rebuildUIContext) + + futureUI.onSuccess { case Some(ui) => + appIdToUI.put(app.id, ui) + self.send(AttachCompletedRebuildUI(app.id)) // Application UI is successfully rebuilt, so link the Master UI to it + // NOTE - app.appUIUrlAtHistoryServer is volatile app.appUIUrlAtHistoryServer = Some(ui.basePath) - Some(ui) - } catch { + }(ThreadUtils.sameThread) + + futureUI.onFailure { case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" @@ -974,7 +994,7 @@ private[deploy] class Master( msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") - None + case e: Exception => // Relay exception message to application UI page val title = s"Application history load error (${app.id})" @@ -984,8 +1004,9 @@ private[deploy] class Master( msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") - None - } + }(ThreadUtils.sameThread) + + futureUI } /** Generate a new app ID given a app's submission date */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index a952cee36eb44..a055d097674ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -39,4 +39,6 @@ private[master] object MasterMessages { case object BoundPortsRequest case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) + + case class AttachCompletedRebuildUI(appId: String) } From 84198a7cd8441a29528d1ff7db13afd76ffab701 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Dec 2015 18:29:19 -0800 Subject: [PATCH 1154/1168] [SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability Author: Wenchen Fan Closes #8645 from cloud-fan/test. --- .../spark/sql/catalyst/dsl/package.scala | 7 ++-- .../optimizer/ColumnPruningSuite.scala | 41 +++++++++++-------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index af594c25c54cb..e50971173c499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -275,13 +275,14 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4a1e7ceaf394b..9bf61ae091786 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning for Generate when Generate.join = false") { val input = LocalRelation('a.int, 'b.array(StringType)) - val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val query = input.generate(Explode('b), join = false).analyze + val optimized = Optimize.execute(query) - val correctAnswer = - Generate(Explode('b), false, false, None, 's.string :: Nil, - Project('b.attr :: Nil, input)).analyze + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze comparePlans(optimized, correctAnswer) } @@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) val query = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - Project(Seq('a, 'c), - input))).analyze + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze comparePlans(optimized, correctAnswer) } @@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('b.array(StringType)) val query = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), false, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze comparePlans(optimized, correctAnswer) } From 98904fda92ee174d89d2ecd15121859df648bd72 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 15 Dec 2015 18:30:59 -0800 Subject: [PATCH 1155/1168] [SPARK-4117][YARN] Spark on Yarn handle AM being told command from RM Spark on Yarn handle AM being told command from RM When RM throws ApplicationAttemptNotFoundException for allocate invocation, making the ApplicationMaster to finish immediately without any retries. Author: Devaraj K Closes #10129 from devaraj-kavali/SPARK-4117. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1970f7d150feb..fc742df73d731 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -376,7 +376,14 @@ private[spark] class ApplicationMaster( case i: InterruptedException => case e: Throwable => { failureCount += 1 - if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + // this exception was introduced in hadoop 2.4 and this code would not compile + // with earlier versions if we refer it directly. + if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" == + e.getClass().getName()) { + logError("Exception from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, + e.getMessage) + } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + s"$failureCount time(s) from Reporter thread.") From 11bb8c159a626631f769d199e064d8d669c09eae Mon Sep 17 00:00:00 2001 From: proflin Date: Tue, 15 Dec 2015 20:22:56 -0800 Subject: [PATCH 1156/1168] =?UTF-8?q?[SPARK-12304][STREAMING]=20Make=20Spa?= =?UTF-8?q?rk=20Streaming=20web=20UI=20display=20more=20fri=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …endly Receiver graphs Currently, the Spark Streaming web UI uses the same maxY when displays 'Input Rate Times& Histograms' and 'Per-Receiver Times& Histograms'. This may lead to somewhat un-friendly graphs: once we have tens of Receivers or more, every 'Per-Receiver Times' line almost hits the ground. This issue proposes to calculate a new maxY against the original one, which is shared among all the `Per-Receiver Times& Histograms' graphs. Before: ![before-5](https://cloud.githubusercontent.com/assets/15843379/11761362/d790c356-a0fa-11e5-860e-4b834603de1d.png) After: ![after-5](https://cloud.githubusercontent.com/assets/15843379/11761361/cfabf692-a0fa-11e5-97d0-4ad124aaca2a.png) Author: proflin Closes #10318 from proflin/SPARK-12304. --- .../org/apache/spark/streaming/ui/StreamingPage.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 88a4483e8068f..b3692c3ea302b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,9 +392,15 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { + val maxYCalculated = listener.receivedEventRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } } + .reduceOption[Double](math.max) + .map(_.ceil.toLong) + .getOrElse(0L) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From d690e1c0bbfa0daa97ed11419f6f35029d89fa77 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:22:49 -0800 Subject: [PATCH 1157/1168] [SPARK-12249][SQL] JDBC non-equality comparison operator not pushed down. https://issues.apache.org/jira/browse/SPARK-12249 Currently `!=` operator is not pushed down correctly. I simply added a case for this. Author: hyukjinkwon Closes #10233 from HyukjinKwon/SPARK-12249. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1c348ed62fc78..c18a2d2cc0768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -281,6 +281,7 @@ private[sql] class JDBCRDD( */ private def compileFilter(f: Filter): String = f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a360947152996..aca1443057343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -177,9 +177,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("SELECT * WHERE (simple predicates)") { assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { From f354717af818ec2314f181cd5d172eae2c0ff027 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:25:08 -0800 Subject: [PATCH 1158/1168] [SPARK-12314][SQL] isnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12314 `IsNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNull` filter to produce a proper filter string. Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10286 from HyukjinKwon/SPARK-12314. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index c18a2d2cc0768..3271b46be18fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -286,6 +286,7 @@ private[sql] class JDBCRDD( case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index aca1443057343..0305667ff66ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -182,6 +182,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) } test("SELECT * WHERE (quoted strings)") { From 1c533ab4b2565ad6d4735468ef66532334e3a40c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:30:35 -0800 Subject: [PATCH 1159/1168] [SPARK-12315][SQL] isnotnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12315 `IsNotNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNotNull` filter to produce a proper filter string. Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10287 from HyukjinKwon/SPARK-12315. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3271b46be18fb..2d38562e0901a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -287,6 +287,7 @@ private[sql] class JDBCRDD( case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0305667ff66ee..d6aeb523ea8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -183,6 +183,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) + assert(stripSparkFilter( + sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) } test("SELECT * WHERE (quoted strings)") { From 140a0b85ec2bc395b01da8da528b2bddce207740 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Dec 2015 22:32:51 -0800 Subject: [PATCH 1160/1168] Style fix for the previous 3 JDBC filter push down commits. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d6aeb523ea8d6..2b91f62c2fa22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,15 +176,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) - assert(stripSparkFilter( - sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) } test("SELECT * WHERE (quoted strings)") { From 29d744da572cc725e7b77bb310df9575e22f977f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:07 -0800 Subject: [PATCH 1161/1168] Revert "[HOTFIX] Compile error from commit 31b3910" This reverts commit 840bd2e008da5b22bfa73c587ea2c57666fffc60. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 33b03be1138be..b69d4411425d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top ${_numRows} $rowsString\n") + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 8b64134728b3ef3415a23349dfef231618be44f1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:34 -0800 Subject: [PATCH 1162/1168] Revert "[SPARK-12105] [SQL] add convenient show functions" This reverts commit 31b391019ff6eb5a483f4b3e62fd082de7ff8416. --- .../org/apache/spark/sql/DataFrame.scala | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d4..497bd48266770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,24 +160,17 @@ class DataFrame private[sql]( } } - /** - * Compose the string representing rows for output - */ - def showString(): String = { - showString(20) - } - /** * Compose the string representing rows for output - * @param numRows Number of rows to show + * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - def showString(numRows: Int, truncate: Boolean = true): String = { - val _numRows = numRows.max(0) + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) val sb = new StringBuilder - val takeResult = take(_numRows + 1) - val hasMoreData = takeResult.length > _numRows - val data = takeResult.take(_numRows) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -231,10 +224,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "_numRows" records + // For Data that has more than "numRows" records if (hasMoreData) { - val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top $_numRows $rowsString\n") + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") } sb.toString() From 95385ea76967b69884d794899b08ec453433c9d2 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 15 Sep 2015 18:11:10 +0200 Subject: [PATCH 1163/1168] [SPARK-10618] [Mesos] Refactoring coarsed-grained scheduling condition and adding test for same --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3adb022822094..d5831d2ddf8d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -303,6 +303,17 @@ private[spark] class CoarseMesosSchedulerBackend( } } + def isOfferValidForScheduling(meetsConstraints: Boolean, + slaveId: String, mem: Double, + cpus: Int, sc: SparkContext): Boolean = { + taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId) + } def isOfferValidForScheduling(meetsConstraints: Boolean, slaveId: String, mem: Double, From 3b01c2a5207cb0a801211c7e7bff69774da997a1 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Sep 2015 11:10:55 +0200 Subject: [PATCH 1164/1168] [SPARK-10618] [Mesos] Killing space and removing duplication, also having one test for each case --- .../CoarseMesosSchedulerBackendSuite.scala | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index b68fcb7e4d13a..3b8713785ee05 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -77,6 +77,14 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite backend } + private def createSchedulerBackendForGivenSparkConf(sc : SparkContext) = { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + createSchedulerBackend(taskScheduler, driver, sc) + } + var sparkConf: SparkConf = _ before { @@ -84,10 +92,12 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") + .set("spark.cores.max", "10") + + sc = new SparkContext(sparkConf) } test("mesos supports killing and limiting executors") { - sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -139,7 +149,6 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } test("mesos supports killing and relaunching tasks with executors") { - sc = new SparkContext(sparkConf) val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] @@ -153,7 +162,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val offer1 = createOffer("o1", "s1", minMem, minCpu) mesosOffers.add(offer1) - val offer2 = createOffer("o2", "s1", minMem, 1); + val offer2 = createOffer("o2", "s1", minMem, 1) backend.resourceOffers(driver, mesosOffers) @@ -185,51 +194,50 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } - test("testing various isOfferValidForScheduling cases") { - val sparkConf = (new SparkConf) - .setMaster("local[*]") - .setAppName("test-mesos-dynamic-alloc") - .setSparkHome("/path") - .set("spark.cores.max", "10") - - val sc = new SparkContext(sparkConf) - - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) - - - val schedulerBackend = createSchedulerBackend(taskScheduler, driver, sc) + test("isOfferValidForScheduling return true when there is a valid offer") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return true when there is a valid offer assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + } - schedulerBackend.slaveIdsWithExecutors += "Slave2" - schedulerBackend.failuresBySlaveId("Slave3") = 2 - schedulerBackend.totalCoresAcquired = 5 + test("isOfferValidForScheduling return false when offer do not meet constraints") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When offer do not meet constraints assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When memory in offer is less than required memory assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - // Return false When cpu in offer is less than required cpu assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + } + + test("isOfferValidForScheduling return false when offer is from slave already running" + + " an executor") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdsWithExecutors += "Slave2" - // Return false When offer is from slave already running an executor assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + } + + test("isOfferValidForScheduling return false when task is failed more than " + + "MAX_SLAVE_FAILURES times on the given slave") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.failuresBySlaveId("Slave3") = 2 - // Return false When task is failed more than MAX_SLAVE_FAILURES times on the given slave assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + } + test("isOfferValidForScheduling return false when max core is already acquired") { + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - // Return false When max core is already acquired assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) - } - - } From cc67b7dd339aa7da58b1affa3f776a9f8ab03e1b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Thu, 17 Sep 2015 18:28:59 +0200 Subject: [PATCH 1165/1168] [SPARK-10514] [CORE] [Mesos] Refactoring fine-grained scheduling condition and adding test for it. --- .../cluster/mesos/MesosSchedulerBackend.scala | 25 +++++--- .../mesos/MesosSchedulerBackendSuite.scala | 62 ++++++++++++++++++- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 281965a5981bb..2617a126889a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -246,14 +246,16 @@ private[spark] class MesosSchedulerBackend( val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check offers for - // 1. Memory requirements - // 2. CPU requirements - need at least 1 for executor, 1 for task - val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) - val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsRequirements = - (meetsMemoryRequirements && meetsCPURequirements) || - (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + isOfferValidForScheduling(cpus, mem, slaveId, sc) + + // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + s"$offerAttributes mem: $mem cpu: $cpus") @@ -328,6 +330,15 @@ private[spark] class MesosSchedulerBackend( } } + def isOfferValidForScheduling(cpusOffered: Double, memory : Double, + slaveId: String, sc : SparkContext): Boolean = { + val meetsMemoryRequirements = memory >= calculateTotalMemory(sc) + val meetsCPURequirements = cpusOffered >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + (meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdToExecutorInfo.contains(slaveId) && cpusOffered >= scheduler.CPUS_PER_TASK) + } + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ def createMesosTask( task: TaskDescription, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index c4dc560031207..11e8a330f9c60 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.Protos._ -import org.apache.mesos.SchedulerDriver +import org.apache.mesos.{Protos, SchedulerDriver} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} @@ -344,4 +344,64 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && r.getRole.equals("prod") }) } + + private def createSchedulerBackendForGivenSparkConf(sc : SparkContext) : MesosSchedulerBackend = { + val conf = new SparkConf + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + new MesosSchedulerBackend(taskScheduler, sc, "master") + } + + test("isOfferValidForScheduling return true when there offer meet cpu and memory requirement") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling( 5, 10000, "Slave1", sc)) + } + + test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling(5, 10, "Slave1", sc) === false) + } + + test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + + assert(schedulerBackend.isOfferValidForScheduling(0, 10000, "Slave1", sc) === false) + } + + test("isOfferValidForScheduling return true when offer is from slave already running and" + + " cpu is less than minimum cpu per task an executor") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdToExecutorInfo("Slave2") = null + + assert(schedulerBackend.isOfferValidForScheduling(2, 10000, "Slave2", sc) === true) + } + + test("isOfferValidForScheduling return false when offer is from slave already running but" + + " cpu is less than minimum cpu per task an executor") { + val sc = mock[SparkContext] + val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) + schedulerBackend.slaveIdToExecutorInfo("Slave2") = null + + assert(schedulerBackend.isOfferValidForScheduling(1, 10000, "Slave2", sc) === false) + } + } From fde2994bf6d71c0de06ae4ba61a6f64627bdb467 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 00:23:30 +0100 Subject: [PATCH 1166/1168] [SPARK-10618] [Mesos] Adressing comments on PR. --- .../mesos/CoarseMesosSchedulerBackend.scala | 26 +++++++++++------ .../cluster/mesos/MesosSchedulerBackend.scala | 12 ++++---- .../CoarseMesosSchedulerBackendSuite.scala | 28 +++++++++---------- .../mesos/MesosSchedulerBackendSuite.scala | 20 ++++++------- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d5831d2ddf8d9..aa847fb6b80e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -254,7 +254,7 @@ private[spark] class CoarseMesosSchedulerBackend( val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue if (meetsConstraints) { - if (isOfferValidForScheduling(meetsConstraints, slaveId, mem, cpus, sc)) { + if (isOfferSatisfiesRequirements(meetsConstraints, slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -303,16 +303,24 @@ private[spark] class CoarseMesosSchedulerBackend( } } - def isOfferValidForScheduling(meetsConstraints: Boolean, + // ToDo: Abstract out each condition and log them. + def isOfferSatisfiesRequirements(meetsConstraints: Boolean, slaveId: String, mem: Double, - cpus: Int, sc: SparkContext): Boolean = { - taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && + cpusOffered: Int, sc: SparkContext): Boolean = { + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpusOffered >= 1 + val needMoreCores = totalCoresAcquired < maxCores + val healthySlave = failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES + val taskOnEachSlaveLessThanExecutorLimit = taskIdToSlaveId.size < executorLimit + val executorNotRunningOnSlave = !slaveIdsWithExecutors.contains(slaveId) + + executorNotRunningOnSlave && + taskOnEachSlaveLessThanExecutorLimit && + needMoreCores && meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId) + meetsMemoryRequirements && + meetsCPURequirements && + healthySlave } def isOfferValidForScheduling(meetsConstraints: Boolean, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 2617a126889a9..9b398749eb48e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -246,14 +246,11 @@ private[spark] class MesosSchedulerBackend( val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task + // check if Attribute constraints is satisfied val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val meetsRequirements = - isOfferValidForScheduling(cpus, mem, slaveId, sc) + isOfferSatisfiesRequirements(cpus, mem, slaveId, sc) // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" @@ -330,7 +327,10 @@ private[spark] class MesosSchedulerBackend( } } - def isOfferValidForScheduling(cpusOffered: Double, memory : Double, + // check if all constraints are satisfied + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task + def isOfferSatisfiesRequirements(cpusOffered: Double, memory : Double, slaveId: String, sc : SparkContext): Boolean = { val meetsMemoryRequirements = memory >= calculateTotalMemory(sc) val meetsCPURequirements = cpusOffered >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 3b8713785ee05..c9bcfce8d846c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -194,50 +194,50 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verify(driver, times(1)).reviveOffers() } - test("isOfferValidForScheduling return true when there is a valid offer") { + test("isOfferSatisfiesRequirements return true when there is a valid offer") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc)) } - test("isOfferValidForScheduling return false when offer do not meet constraints") { + test("isOfferSatisfiesRequirements return false when offer do not meet constraints") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(false, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(false, "Slave1", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 1, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 1, 5, sc) === false) } - test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 0, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 0, sc) === false) } - test("isOfferValidForScheduling return false when offer is from slave already running" + + test("isOfferSatisfiesRequirements return false when offer is from slave already running" + " an executor") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdsWithExecutors += "Slave2" - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave2", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave2", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when task is failed more than " + + test("isOfferSatisfiesRequirements return false when task is failed more than " + "MAX_SLAVE_FAILURES times on the given slave") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.failuresBySlaveId("Slave3") = 2 - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave3", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave3", 10000, 5, sc) === false) } - test("isOfferValidForScheduling return false when max core is already acquired") { + test("isOfferSatisfiesRequirements return false when max core is already acquired") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - assert(schedulerBackend.isOfferValidForScheduling(true, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc) === false) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 11e8a330f9c60..cbfd4bcae7bf2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -365,43 +365,43 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi new MesosSchedulerBackend(taskScheduler, sc, "master") } - test("isOfferValidForScheduling return true when there offer meet cpu and memory requirement") { + test("isOfferSatisfiesRequirements return true when there offer meet cpu and memory requirement") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling( 5, 10000, "Slave1", sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements( 5, 10000, "Slave1", sc)) } - test("isOfferValidForScheduling return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(5, 10, "Slave1", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(5, 10, "Slave1", sc) === false) } - test("isOfferValidForScheduling return false when cpu in offer is less than required cpu") { + test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferValidForScheduling(0, 10000, "Slave1", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(0, 10000, "Slave1", sc) === false) } - test("isOfferValidForScheduling return true when offer is from slave already running and" + + test("isOfferSatisfiesRequirements return true when offer is from slave already running and" + " cpu is less than minimum cpu per task an executor") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdToExecutorInfo("Slave2") = null - assert(schedulerBackend.isOfferValidForScheduling(2, 10000, "Slave2", sc) === true) + assert(schedulerBackend.isOfferSatisfiesRequirements(2, 10000, "Slave2", sc) === true) } - test("isOfferValidForScheduling return false when offer is from slave already running but" + + test("isOfferSatisfiesRequirements return false when offer is from slave already running but" + " cpu is less than minimum cpu per task an executor") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdToExecutorInfo("Slave2") = null - assert(schedulerBackend.isOfferValidForScheduling(1, 10000, "Slave2", sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements(1, 10000, "Slave2", sc) === false) } } From 66f1750fd73be16481d0ae5bd1b563f5649d2d27 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 11:06:17 +0100 Subject: [PATCH 1167/1168] [SPARK-10618] [Mesos] removing meetconstraints check as it is already done --- .../mesos/CoarseMesosSchedulerBackend.scala | 8 +++----- .../CoarseMesosSchedulerBackendSuite.scala | 17 ++++++----------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index aa847fb6b80e8..f5d9e0884b055 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -254,7 +254,7 @@ private[spark] class CoarseMesosSchedulerBackend( val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue if (meetsConstraints) { - if (isOfferSatisfiesRequirements(meetsConstraints, slaveId, mem, cpus, sc)) { + if (isOfferSatisfiesRequirements(slaveId, mem, cpus, sc)) { // Launch an executor on the slave val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse @@ -304,9 +304,8 @@ private[spark] class CoarseMesosSchedulerBackend( } // ToDo: Abstract out each condition and log them. - def isOfferSatisfiesRequirements(meetsConstraints: Boolean, - slaveId: String, mem: Double, - cpusOffered: Int, sc: SparkContext): Boolean = { + def isOfferSatisfiesRequirements(slaveId: String, mem: Double, cpusOffered: Int, + sc: SparkContext): Boolean = { val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpusOffered >= 1 val needMoreCores = totalCoresAcquired < maxCores @@ -317,7 +316,6 @@ private[spark] class CoarseMesosSchedulerBackend( executorNotRunningOnSlave && taskOnEachSlaveLessThanExecutorLimit && needMoreCores && - meetsConstraints && meetsMemoryRequirements && meetsCPURequirements && healthySlave diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index c9bcfce8d846c..f6517c9090415 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -197,25 +197,20 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite test("isOfferSatisfiesRequirements return true when there is a valid offer") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc)) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 5, sc)) } - test("isOfferSatisfiesRequirements return false when offer do not meet constraints") { - val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - - assert(schedulerBackend.isOfferSatisfiesRequirements(false, "Slave1", 10000, 5, sc) === false) - } test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 1, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 1, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when cpu in offer is less than required cpu") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 0, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 0, sc) === false) } test("isOfferSatisfiesRequirements return false when offer is from slave already running" + @@ -223,7 +218,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.slaveIdsWithExecutors += "Slave2" - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave2", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave2", 10000, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when task is failed more than " + @@ -231,13 +226,13 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.failuresBySlaveId("Slave3") = 2 - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave3", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave3", 10000, 5, sc) === false) } test("isOfferSatisfiesRequirements return false when max core is already acquired") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) schedulerBackend.totalCoresAcquired = 10 - assert(schedulerBackend.isOfferSatisfiesRequirements(true, "Slave1", 10000, 5, sc) === false) + assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 10000, 5, sc) === false) } } From c10af6639e27e301a0b698d725d7dadaa271127b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Wed, 16 Dec 2015 14:37:13 +0100 Subject: [PATCH 1168/1168] [SPARK-10618] [Mesos] checking in missed Style check fix --- .../cluster/mesos/CoarseMesosSchedulerBackendSuite.scala | 3 ++- .../cluster/mesos/MesosSchedulerBackendSuite.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index f6517c9090415..9a4f9b7352fd4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -201,7 +201,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite } - test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less" + + " than required memory") { val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) assert(schedulerBackend.isOfferSatisfiesRequirements("Slave1", 1, 5, sc) === false) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index cbfd4bcae7bf2..9e013a19c6b18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -365,14 +365,16 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi new MesosSchedulerBackend(taskScheduler, sc, "master") } - test("isOfferSatisfiesRequirements return true when there offer meet cpu and memory requirement") { + test("isOfferSatisfiesRequirements return true when there offer meet cpu and" + + " memory requirement") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc) assert(schedulerBackend.isOfferSatisfiesRequirements( 5, 10000, "Slave1", sc)) } - test("isOfferSatisfiesRequirements return false when memory in offer is less than required memory") { + test("isOfferSatisfiesRequirements return false when memory in offer is less " + + "than required memory") { val sc = mock[SparkContext] val schedulerBackend = createSchedulerBackendForGivenSparkConf(sc)