From 81d6933e75579343b1dd14792c18149e97e92cdd Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Mon, 24 Oct 2016 15:33:02 -0700 Subject: [PATCH 001/381] [SPARK-17894][CORE] Ensure uniqueness of TaskSetManager name. `TaskSetManager` should have a unique name to avoid adding duplicate ones to the parent `Pool` via `SchedulableBuilder`. This problem surfaced in the following discussion: [[PR: Avoid adding duplicate schedulables]](https://github.com/apache/spark/pull/15326) **Proposal**: There is a one-to-one relationship between `stageAttemptId` and `TaskSetManager`, so `taskSet.id`, which covers both `stageId` and `stageAttemptId`, is used for the `TaskSetManager` name instead of just `stageId`. **Current TaskSetManager Name**: `var name = "TaskSet_" + taskSet.stageId.toString` **Sample**: TaskSet_0 **Proposed TaskSetManager Name**: `val name = "TaskSet_" + taskSet.id` `// taskSet.id = (stageId + "." + stageAttemptId)` **Sample**: TaskSet_0.0 Added a new unit test. Author: erenavsarogullari Closes #15463 from erenavsarogullari/SPARK-17894. --- .../spark/scheduler/TaskSetManager.scala | 2 +- .../org/apache/spark/scheduler/FakeTask.scala | 13 ++++++++---- .../spark/scheduler/TaskSetManagerSuite.scala | 20 ++++++++++++++++++- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 9491bc7a0497..b766e4148e49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,7 +79,7 @@ private[spark] class TaskSetManager( var minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId - var name = "TaskSet_" + taskSet.stageId.toString + val name = "TaskSet_" + taskSet.id var parent: Pool = null var totalResultSize = 0L var calculatedTasks = 0 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 87600fe504b9..f395fe9804c9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -22,7 +22,7 @@ import org.apache.spark.TaskContext class FakeTask( stageId: Int, partitionId: Int, - prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) { + prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, stageAttemptId = 0, partitionId) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } @@ -33,16 +33,21 @@ object FakeTask { /** * locations for each task (given as varargs) if this sequence is not empty.
*/ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - createTaskSet(numTasks, 0, prefLocs: _*) + createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*) } def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): + TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(0, i, if (prefLocs.size != 0) prefLocs(i) else Nil) + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, stageAttemptId, 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 69edcf334724..b49ba085ca5d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -904,7 +904,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg task.index == index && !sched.endedTasks.contains(task.taskId) }.getOrElse { throw new RuntimeException(s"couldn't find index $index in " + - s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" + + s"tasks: ${tasks.map { t => t.index -> t.taskId }} with endedTasks:" + s" ${sched.endedTasks.keys}") } } @@ -974,6 +974,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.isZombie) } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock) + assert(manager.name === "TaskSet_0.0") + + // Make sure a task set with the same stage ID but different attempt ID has a unique name + val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 1) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock) + assert(manager2.name === "TaskSet_0.1") + + // Make sure a task set with the same attempt ID but different stage ID also has a unique name + val taskSet3 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 1) + val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock) + assert(manager3.name === "TaskSet_1.1") + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From 407c3cedf29a4413339dcde758295dc3225a0054 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 24 Oct 2016 17:21:16 -0700 Subject: [PATCH 002/381] [SPARK-17624][SQL][STREAMING][TEST] Fixed flaky StateStoreSuite.maintenance ## What changes were proposed in this pull request? The reason for the flakiness was follows. The test starts the maintenance background thread, and then writes 20 versions of the state store. The maintenance thread is expected to create snapshots in the middle, and clean up old files that are not needed any more. 
The earliest delta file (1.delta) is expected to be deleted as snapshots will ensure that the earliest delta would not be needed. However, the default configuration for the maintenance thread is to retain files such that the last 2 versions can be recovered, and delete the rest. Now while generating the versions, the maintenance thread can kick in and create snapshots anywhere between version 10 and 20 (at least 10 deltas are needed for a snapshot). Later it will choose to retain only versions 20 and 19 (the last 2). There are two cases. - Common case: One of the versions between 10 and 19 gets snapshotted. Then recovering versions 19 and 20 just needs 19.snapshot and 20.delta, so 1.delta gets deleted. - Uncommon case (reason for flakiness): Only version 20 gets snapshotted. Then recovering version 20 requires 20.snapshot, while recovering version 19 requires all the previous deltas, 19...1.delta. So 1.delta does not get deleted. This PR rearranges the checks such that the test creates 20 versions, then waits until there is at least one snapshot, and then creates another 20. This ensures that the latest 2 versions cannot require anything older than the first snapshot generated, and therefore 1.delta will be deleted. In addition, I have added more logs and comments that I felt would help future debugging and understanding of what is going on. ## How was this patch tested? Ran the StateStoreSuite > 6K times on a heavily loaded machine (10 instances of tests running in parallel). No failures. Author: Tathagata Das Closes #15592 from tdas/SPARK-17624. --- .../state/HDFSBackedStateStoreProvider.scala | 18 ++++--- .../state/StateStoreCoordinator.scala | 18 +++++-- .../streaming/state/StateStoreSuite.scala | 49 ++++++++++++------- 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 7d71f5242c27..f1e7f1d113ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -159,7 +159,7 @@ private[state] class HDFSBackedStateStoreProvider( } catch { case NonFatal(e) => throw new IllegalStateException( - s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + s"Error committing version $newVersion into $this", e) } } @@ -205,6 +205,10 @@ private[state] class HDFSBackedStateStoreProvider( override private[state] def hasCommitted: Boolean = { state == COMMITTED } + + override def toString(): String = { + s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + } } /** Get the state store for making updates to create a new `version` of the store.
*/ @@ -215,7 +219,7 @@ private[state] class HDFSBackedStateStoreProvider( newMap.putAll(loadMap(version)) } val store = new HDFSBackedStateStore(version, newMap) - logInfo(s"Retrieved version $version of $this for update") + logInfo(s"Retrieved version $version of ${HDFSBackedStateStoreProvider.this} for update") store } @@ -231,7 +235,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } /* Internal classes and methods */ @@ -493,10 +497,12 @@ private[state] class HDFSBackedStateStoreProvider( val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq mapsToRemove.foreach(loadedMaps.remove) } - files.filter(_.version < earliestFileToRetain.version).foreach { f => + val filesToDelete = files.filter(_.version < earliestFileToRetain.version) + filesToDelete.foreach { f => fs.delete(f.path, true) } - logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + + filesToDelete.mkString(", ")) } } } catch { @@ -560,7 +566,7 @@ private[state] class HDFSBackedStateStoreProvider( } } val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) - logDebug(s"Current set of files for $this: $storeFiles") + logDebug(s"Current set of files for $this: ${storeFiles.mkString(", ")}") storeFiles } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d945d7aff2da..267d17623d5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -38,7 +38,7 @@ private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: Str private case class GetLocation(storeId: StateStoreId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(storeRootLocation: String) +private case class DeactivateInstances(checkpointLocation: String) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -111,11 +111,13 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => + logDebug(s"Reported state store $id is active at $executorId") instances.put(id, ExecutorCacheTaskLocation(host, executorId)) } @@ -125,19 +127,25 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadS case Some(location) => location.executorId == execId case None => false } + logDebug(s"Verified that state store $id is active: $response") context.reply(response) case GetLocation(id) => - context.reply(instances.get(id).map(_.toString)) + val executorId = instances.get(id).map(_.toString) + logDebug(s"Got location of the state store $id: $executorId") + context.reply(executorId) - case DeactivateInstances(loc) => + case DeactivateInstances(checkpointLocation) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == loc).toSeq + instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq instances --= storeIdsToRemove + logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + storeIdsToRemove.mkString(", ")) context.reply(true) case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered + logInfo("StateStoreCoordinator stopped") context.reply(true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 06f1bd6c3bcc..fcf300b3c81b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -367,7 +367,10 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val conf = new SparkConf() .setMaster("local") .setAppName("test") + // Make maintenance thread do snapshots and cleanups very fast .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly' + // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString @@ -377,37 +380,49 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = new HDFSBackedStateStoreProvider( storeId, keySchema, valueSchema, storeConf, hadoopConf) + var latestStoreVersion = 0 + + def generateStoreVersions() { + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + put(store, "a", i) + store.commit() + latestStoreVersion += 1 + } + } quietly { withSpark(new SparkContext(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running") - for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - put(store, "a", i) - store.commit() - } + // Generate sufficient versions of store for snapshots + generateStoreVersions() eventually(timeout(10 seconds)) { + 
// Store should have been reported to the coordinator assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") - } - // Background maintenance should clean up and generate snapshots - assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") - - eventually(timeout(10 seconds)) { - // Earliest delta file should get cleaned up - assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + // Background maintenance should clean up and generate snapshots + assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") // Some snapshots should have been generated - val snapshotVersions = (0 to 20).filter { version => + val snapshotVersions = (1 to latestStoreVersion).filter { version => fileExists(provider, version, isSnapshot = true) } assert(snapshotVersions.nonEmpty, "no snapshot file found") } + // Generate more versions such that there is another snapshot and + // the earliest delta file will be cleaned up + generateStoreVersions() + + // Earliest delta file should get cleaned up + eventually(timeout(10 seconds)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + } + // If driver decides to deactivate all instances of the store, then this instance // should be unloaded coordinatorRef.deactivateInstances(dir) @@ -416,7 +431,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -426,14 +441,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) } } // Verify if instance is unloaded if SparkContext is stopped - require(SparkEnv.get === null) eventually(timeout(10 seconds)) { + require(SparkEnv.get === null) assert(!StateStore.isLoaded(storeId)) assert(!StateStore.isMaintenanceRunning) } From 84a33999082af88ea6365cdb5c7232ed0933b1c6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 25 Oct 2016 08:42:21 +0800 Subject: [PATCH 003/381] [SPARK-18028][SQL] simplify TableFileCatalog ## What changes were proposed in this pull request? Simplify/cleanup TableFileCatalog: 1. pass a `CatalogTable` instead of `databaseName` and `tableName` into `TableFileCatalog`, so that we don't need to fetch table metadata from metastore again 2. In `TableFileCatalog.filterPartitions0`, DO NOT set `PartitioningAwareFileCatalog.BASE_PATH_PARAM`. According to the [classdoc](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala#L189-L209), the default value of `basePath` already satisfies our need. What's more, if we set this parameter, we may break the case 2 which is metioned in the classdoc. 3. add `equals` and `hashCode` to `TableFileCatalog` 4. add `SessionCatalog.listPartitionsByFilter` which handles case sensitivity. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #15568 from cloud-fan/table-file-catalog. 
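For illustration only (not part of the patch): the constructor change in item 1 means call sites that previously passed a database name, table name, and optional partition schema now pass the `CatalogTable` they already hold, so no second metastore lookup is needed. A minimal sketch of a before/after call site, using the same names as the test suites in the diff below:

```scala
// Assumed setup, mirroring PruneFileSourcePartitionsSuite below: a CatalogTable fetched once.
val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")

// Before: TableFileCatalog re-fetched the table metadata from the metastore internally.
// new TableFileCatalog(spark, tableMeta.database, tableMeta.identifier.table,
//   Some(tableMeta.partitionSchema), 0)

// After: the CatalogTable is passed in directly; two instances built over the same table
// identifier now compare equal, which makes cache lookups by plan work (item 3).
val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0)
```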
--- .../sql/catalyst/catalog/SessionCatalog.scala | 14 +++++ .../datasources/TableFileCatalog.scala | 54 ++++++++++--------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 41 +++++++++++++- .../PruneFileSourcePartitionsSuite.scala | 7 +-- 5 files changed, 84 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9711131d88a0..3d6eec81c03c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -755,6 +755,20 @@ class SessionCatalog( externalCatalog.listPartitions(db, table, partialSpec) } + /** + * List the metadata of partitions that belong to the specified table, assuming it exists, that + * satisfy the given partition-pruning predicate expressions. + */ + def listPartitionsByFilter( + tableName: TableIdentifier, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + externalCatalog.listPartitionsByFilter(db, table, predicates) + } + /** * Verify if the input partition spec exactly matches the existing defined partition spec * The columns must be the same but the orders could be different. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala index 31a01bc6db08..667379b222c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala @@ -20,36 +20,30 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.StructType /** * A [[FileCatalog]] for a metastore catalog table. 
* * @param sparkSession a [[SparkSession]] - * @param db the table's database name - * @param table the table's (unqualified) name - * @param partitionSchema the schema of a partitioned table's partition columns + * @param table the metadata of the table * @param sizeInBytes the table's data size in bytes - * @param fileStatusCache optional cache implementation to use for file listing */ class TableFileCatalog( sparkSession: SparkSession, - db: String, - table: String, - partitionSchema: Option[StructType], + val table: CatalogTable, override val sizeInBytes: Long) extends FileCatalog { protected val hadoopConf = sparkSession.sessionState.newHadoopConf private val fileStatusCache = FileStatusCache.newCache(sparkSession) - private val externalCatalog = sparkSession.sharedState.externalCatalog + assert(table.identifier.database.isDefined, + "The table identifier must be qualified in TableFileCatalog") - private val catalogTable = externalCatalog.getTable(db, table) - - private val baseLocation = catalogTable.storage.locationUri + private val baseLocation = table.storage.locationUri override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq @@ -66,24 +60,32 @@ class TableFileCatalog( * @param filters partition-pruning filters */ def filterPartitions(filters: Seq[Expression]): ListingFileCatalog = { - val parameters = baseLocation - .map(loc => Map(PartitioningAwareFileCatalog.BASE_PATH_PARAM -> loc)) - .getOrElse(Map.empty) - partitionSchema match { - case Some(schema) => - val selectedPartitions = externalCatalog.listPartitionsByFilter(db, table, filters) - val partitions = selectedPartitions.map { p => - PartitionPath(p.toRow(schema), p.storage.locationUri.get) - } - val partitionSpec = PartitionSpec(schema, partitions) - new PrunedTableFileCatalog( - sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) - case None => - new ListingFileCatalog(sparkSession, rootPaths, parameters, None, fileStatusCache) + if (table.partitionColumnNames.nonEmpty) { + val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( + table.identifier, filters) + val partitionSchema = table.partitionSchema + val partitions = selectedPartitions.map { p => + PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) + } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + new PrunedTableFileCatalog( + sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) + } else { + new ListingFileCatalog(sparkSession, rootPaths, table.storage.properties, None) } } override def inputFiles: Array[String] = filterPartitions(Nil).inputFiles + + // `TableFileCatalog` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member + // of `LogicalRelation`, and `LogicalRelation` may be used as the cache key. So we need to + // implement `equals` and `hashCode` here, to make it work with cache lookup. 
+ override def equals(o: Any): Boolean = o match { + case other: TableFileCatalog => this.table.identifier == other.table.identifier + case _ => false + } + + override def hashCode(): Int = table.identifier.hashCode() } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 44089335e1a1..6c1585d5f561 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -226,12 +226,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val db = metastoreRelation.databaseName - val table = metastoreRelation.tableName val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong val fileCatalog = { val catalog = new TableFileCatalog( - sparkSession, db, table, Some(partitionSchema), sizeInBytes) + sparkSession, metastoreRelation.catalogTable, sizeInBytes) if (lazyPruningEnabled) { catalog } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 7d4ef6f26a60..ecdf4f14b398 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, TableFileCatalog} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils @@ -317,4 +320,40 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("DROP TABLE cachedTable") } + + test("cache a table using TableFileCatalog") { + withTable("test") { + sql("CREATE TABLE test(i int) PARTITIONED BY (p int) STORED AS parquet") + val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test") + val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0) + + val dataSchema = StructType(tableMeta.schema.filterNot { f => + tableMeta.partitionColumnNames.contains(f.name) + }) + val relation = HadoopFsRelation( + location = tableFileCatalog, + partitionSchema = tableMeta.partitionSchema, + dataSchema = dataSchema, + bucketSpec = None, + fileFormat = new ParquetFileFormat(), + options = Map.empty)(sparkSession = spark) + + val plan = LogicalRelation(relation, catalogTable = Some(tableMeta)) + spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) + + assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) + + val sameCatalog = new TableFileCatalog(spark, tableMeta, 0) + val sameRelation = HadoopFsRelation( + location = sameCatalog, + partitionSchema = tableMeta.partitionSchema, + dataSchema = dataSchema, + 
bucketSpec = None, + fileFormat = new ParquetFileFormat(), + options = Map.empty)(sparkSession = spark) + val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta)) + + assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 346ea0ca4367..59639aacf3a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -45,12 +45,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te |LOCATION '${dir.getAbsolutePath}'""".stripMargin) val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test") - val tableFileCatalog = new TableFileCatalog( - spark, - tableMeta.database, - tableMeta.identifier.table, - Some(tableMeta.partitionSchema), - 0) + val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0) val dataSchema = StructType(tableMeta.schema.filterNot { f => tableMeta.partitionColumnNames.contains(f.name) From d479c5262276b47302659bd877a9e3467400bdb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 25 Oct 2016 10:47:11 +0800 Subject: [PATCH 004/381] [SPARK-17409][SQL][FOLLOW-UP] Do Not Optimize Query in CTAS More Than Once ### What changes were proposed in this pull request? This follow-up PR is for addressing the [comment](https://github.com/apache/spark/pull/15048). We added two test cases based on the suggestion from yhuai . One is a new test case using the `saveAsTable` API to create a data source table. Another is for CTAS on Hive serde table. Note: No need to backport this PR to 2.0. Will submit a new PR to backport the whole fix with new test cases to Spark 2.0 ### How was this patch tested? N/A Author: gatorsmile Closes #15459 from gatorsmile/ctasOptimizedTestCases. 
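As a sketch of what the new tests exercise (illustrative only; identifiers mirror the test code in the diff below), the regression being guarded against is CTAS re-optimizing an already-optimized child plan:

```scala
// A grouped query whose analyzed plan must not be optimized a second time inside CTAS.
spark.sql("select 0 as id").createOrReplaceTempView("foo")
val df = spark.sql("select * from foo group by id")

// Data source table path (saveAsTable). If the plan were optimized twice, this would fail with:
// "GROUP BY position 0 is not in select list (valid range is [1, 1])"
df.write.mode("overwrite").saveAsTable("bar")

// Hive serde table path (SQL CTAS), covering the same double-optimization scenario.
// (The table name here is illustrative; the test itself reuses `bar` in a separate scope.)
spark.sql("CREATE TABLE bar_hive AS SELECT * FROM foo GROUP BY id")
```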
--- .../org/apache/spark/sql/DataFrameSuite.scala | 18 +++++++++++++++++ .../sources/CreateTableAsSelectSuite.scala | 2 +- .../sql/hive/MetastoreRelationSuite.scala | 20 +++++++++++++++++-- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e87baa454c8b..3fb7eeefba67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1599,6 +1599,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } + test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") { + withTable("bar") { + withTempView("foo") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + sql("select 0 as id").createOrReplaceTempView("foo") + val df = sql("select * from foo group by id") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + df.write.mode("overwrite").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("json"), + "the expected table is a data source table using json") + } + } + } + } + test("copy results for sampling with replacement") { val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") val sampleDf = df.sample(true, 2.00) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index c39005f6a106..5cc9467395ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -238,7 +238,7 @@ class CreateTableAsSelectSuite } } - test("CTAS of decimal calculation") { + test("SPARK-17409: CTAS of decimal calculation") { withTable("tab2") { withTempView("tab1") { spark.range(99, 101).createOrReplaceTempView("tab1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala index c28e41a85c39..91ff711445e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -class MetastoreRelationSuite extends SparkFunSuite { +class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("makeCopy and toJSON should work") { val table = CatalogTable( identifier = TableIdentifier("test", Some("db")), @@ -36,4 +38,18 @@ class MetastoreRelationSuite extends SparkFunSuite { // No exception should be thrown relation.toJSON } + 
+ test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") { + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("CREATE TABLE bar AS SELECT * FROM foo group by id") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table") + } + } + } } From 483c37c581fedc64b218e294ecde1a7bb4b2af9c Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Mon, 24 Oct 2016 20:16:00 -0700 Subject: [PATCH 005/381] [SPARK-17894][HOTFIX] Fix broken build from The named parameter in an overridden class isn't supported in Scala 2.10 so was breaking the build. cc zsxwing Author: Kay Ousterhout Closes #15617 from kayousterhout/hotfix. --- core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index f395fe9804c9..a75704129941 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -22,7 +22,7 @@ import org.apache.spark.TaskContext class FakeTask( stageId: Int, partitionId: Int, - prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, stageAttemptId = 0, partitionId) { + prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } From 78d740a08a04b74b49b5cba4bb6a821631390ab4 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 24 Oct 2016 23:47:59 -0700 Subject: [PATCH 006/381] [SPARK-17748][ML] One pass solver for Weighted Least Squares with ElasticNet ## What changes were proposed in this pull request? 1. Make a pluggable solver interface for `WeightedLeastSquares` 2. Add a `QuasiNewton` solver to handle elastic net regularization for `WeightedLeastSquares` 3. Add method `BLAS.dspmv` used by QN solver 4. Add mechanism for WLS to handle singular covariance matrices by falling back to QN solver when Cholesky fails. ## How was this patch tested? Unit tests - see below. ## Design choices **Pluggable Normal Solver** Before, the `WeightedLeastSquares` package always used the Cholesky decomposition solver to compute the solution to the normal equations. Now, we specify the solver as a constructor argument to the `WeightedLeastSquares`. We introduce a new trait: ````scala private[ml] sealed trait NormalEquationSolver { def solve( bBar: Double, bbBar: Double, abBar: DenseVector, aaBar: DenseVector, aBar: DenseVector): NormalEquationSolution } ```` We extend this trait for different variants of normal equation solvers. In the future, we can easily add others (like QR) using this interface. **Always train in the standardized space** The normal solver did not previously standardize the data, but this patch introduces a change such that we always solve the normal equations in the standardized space. We convert back to the original space in the same way that is done for distributed L-BFGS/OWL-QN. We add test cases for zero variance features/labels. 
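As a rough, illustrative sketch (names here are mine, not the patch's), the conversion back from the standardized space matches the rescaling added to `WeightedLeastSquares.fit` in the diff below: each standardized coefficient is multiplied by `bStd / aStd(j)` (or set to 0.0 for a constant column), and the intercept is multiplied by `bStd`:

```scala
// Illustrative helper only: map coefficients solved in the standardized space back to the
// original feature/label scale.
def unscale(
    coefStd: Array[Double],   // coefficients from the standardized normal equations
    interceptStd: Double,     // intercept in the standardized space (0.0 if fitIntercept is false)
    aStd: Array[Double],      // per-feature population standard deviations
    bStd: Double              // label standard deviation
  ): (Array[Double], Double) = {
  val coef = coefStd.zip(aStd).map { case (c, std) =>
    if (std != 0.0) c * bStd / std else 0.0  // constant columns keep a zero coefficient
  }
  (coef, interceptStd * bStd)
}
```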
**Use L-BFGS locally to solve normal equations for singular matrix** When linear regression with the normal solver is called for a singular matrix, we initially try to solve with Cholesky. We use the output of `lapack.dppsv` to determine if the matrix is singular. If it is, we fall back to using L-BFGS locally to solve the normal equations. We add test cases for this as well. ## Test cases I found it helpful to enumerate some of the test cases and hopefully it makes review easier. **WeightedLeastSquares** 1. Constant columns - Cholesky solver fails with no regularization, Auto solver falls back to QN, and QN trains successfully. 2. Collinear features - Cholesky solver fails with no regularization, Auto solver falls back to QN, and QN trains successfully. 3. Label is constant zero - no training is performed regardless of intercept. Coefficients are zero and intercept is zero. 4. Label is constant - if fitIntercept, then no training is performed and intercept equals label mean. If not fitIntercept, then we train and return an answer that matches R's lm package. 5. Test with L1 - go through various combinations of L1/L2, standardization, fitIntercept and verify that output matches glmnet. 6. Initial intercept - verify that setting the initial intercept to label mean is correct by training model with strong L1 regularization so that all coefficients are zero and intercept converges to label mean. 7. Test diagInvAtWA - since we are standardizing features now during training, we should test that the inverse is computed to match R. **LinearRegression** 1. For all existing L1 test cases, test the "normal" solver too. 2. Check that using the normal solver now handles singular matrices. 3. Check that using the normal solver with L1 produces an objective history in the model summary, but does not produce the inverse of AtA. **BLAS** 1. Test new method `dspmv`. ## Performance Testing This patch will speed up linear regression with L1/elasticnet penalties when the feature size is < 4096. I have not conducted performance tests at scale, only observed by testing locally that there is a speed improvement. We should decide if this PR needs to be blocked before performance testing is conducted. Author: sethah Closes #15394 from sethah/SPARK-17748. --- .../org/apache/spark/ml/linalg/BLAS.scala | 18 + .../apache/spark/ml/linalg/BLASSuite.scala | 45 ++ .../IterativelyReweightedLeastSquares.scala | 4 +- .../spark/ml/optim/NormalEquationSolver.scala | 163 +++++++ .../spark/ml/optim/WeightedLeastSquares.scala | 270 +++++++++-- .../GeneralizedLinearRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 20 +- .../mllib/linalg/CholeskyDecomposition.scala | 4 +- ...erativelyReweightedLeastSquaresSuite.scala | 6 +- .../ml/optim/WeightedLeastSquaresSuite.scala | 400 ++++++++++++++-- .../ml/regression/LinearRegressionSuite.scala | 431 +++++++++--------- 11 files changed, 1057 insertions(+), 308 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 4ca19f3387f0..ef3890962494 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -243,6 +243,24 @@ private[spark] object BLAS extends Serializable { spr(alpha, v, U.values) } + /** + * y := alpha*A*x + beta*y + * + * @param n The order of the n by n matrix A. 
+ * @param A The upper triangular part of A in a [[DenseVector]] (column major). + * @param x The [[DenseVector]] transformed by A. + * @param y The [[DenseVector]] to be modified in place. + */ + def dspmv( + n: Int, + alpha: Double, + A: DenseVector, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + f2jBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1) + } + /** * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. * diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 6e72a5fff0a9..877ac6898334 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -422,4 +422,49 @@ class BLASSuite extends SparkMLFunSuite { assert(dATT.multiply(sx) ~== expected absTol 1e-15) assert(sATT.multiply(sx) ~== expected absTol 1e-15) } + + test("spmv") { + /* + A = [[3.0, -2.0, 2.0, -4.0], + [-2.0, -8.0, 4.0, 7.0], + [2.0, 4.0, -3.0, -3.0], + [-4.0, 7.0, -3.0, 0.0]] + x = [5.0, 2.0, -1.0, -9.0] + Ax = [ 45., -93., 48., -3.] + */ + val A = new DenseVector(Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0)) + val x = new DenseVector(Array(5.0, 2.0, -1.0, -9.0)) + val n = 4 + + val y1 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0)) + val y2 = y1.copy + val y3 = y1.copy + val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + + val expected1 = new DenseVector(Array(42.0, -87.0, 40.0, -6.0)) + val expected2 = new DenseVector(Array(19.5, -40.5, 16.0, -4.5)) + val expected3 = new DenseVector(Array(-25.5, 52.5, -32.0, -1.5)) + val expected4 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0)) + val expected5 = new DenseVector(Array(43.5, -90.0, 44.0, -4.5)) + val expected6 = new DenseVector(Array(46.5, -96.0, 52.0, -1.5)) + val expected7 = new DenseVector(Array(45.0, -93.0, 48.0, -3.0)) + + dspmv(n, 1.0, A, x, 1.0, y1) + dspmv(n, 0.5, A, x, 1.0, y2) + dspmv(n, -0.5, A, x, 1.0, y3) + dspmv(n, 0.0, A, x, 1.0, y4) + dspmv(n, 1.0, A, x, 0.5, y5) + dspmv(n, 1.0, A, x, -0.5, y6) + dspmv(n, 1.0, A, x, 0.0, y7) + assert(y1 ~== expected1 absTol 1e-8) + assert(y2 ~== expected2 absTol 1e-8) + assert(y3 ~== expected3 absTol 1e-8) + assert(y4 ~== expected4 absTol 1e-8) + assert(y5 ~== expected5 absTol 1e-8) + assert(y6 ~== expected6 absTol 1e-8) + assert(y7 ~== expected7 absTol 1e-8) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index d732f53029e8..8a6b862cda17 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -81,8 +81,8 @@ private[ml] class IterativelyReweightedLeastSquares( } // Estimate new model - model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, - standardizeLabel = false).fit(newInstances) + model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) // Check convergence val oldCoefficients = oldModel.coefficients diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala new file mode 100644 index 
000000000000..2f5299b01022 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import scala.collection.mutable + +import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.linalg.CholeskyDecomposition + +/** + * A class to hold the solution to the normal equations A^T^ W A x = A^T^ W b. + * + * @param coefficients The least squares coefficients. The last element in the coefficients + * is the intercept when bias is added to A. + * @param aaInv An option containing the upper triangular part of (A^T^ W A)^-1^, in column major + * format. None when an optimization program is used to solve the normal equations. + * @param objectiveHistory Option containing the objective history when an optimization program is + * used to solve the normal equations. None when an analytic solver is used. + */ +private[ml] class NormalEquationSolution( + val coefficients: Array[Double], + val aaInv: Option[Array[Double]], + val objectiveHistory: Option[Array[Double]]) + +/** + * Interface for classes that solve the normal equations locally. + */ +private[ml] sealed trait NormalEquationSolver { + + /** Solve the normal equations from summary statistics. */ + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution +} + +/** + * A class that solves the normal equations directly, using Cholesky decomposition. + */ +private[ml] class CholeskySolver extends NormalEquationSolver { + + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution = { + val k = abBar.size + val x = CholeskyDecomposition.solve(aaBar.values, abBar.values) + val aaInv = CholeskyDecomposition.inverse(aaBar.values, k) + + new NormalEquationSolution(x, Some(aaInv), None) + } +} + +/** + * A class for solving the normal equations using Quasi-Newton optimization methods. 
+ */ +private[ml] class QuasiNewtonSolver( + fitIntercept: Boolean, + maxIter: Int, + tol: Double, + l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver { + + def solve( + bBar: Double, + bbBar: Double, + abBar: DenseVector, + aaBar: DenseVector, + aBar: DenseVector): NormalEquationSolution = { + val numFeatures = aBar.size + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures + val initialCoefficientsWithIntercept = new Array[Double](numFeaturesPlusIntercept) + if (fitIntercept) { + initialCoefficientsWithIntercept(numFeaturesPlusIntercept - 1) = bBar + } + + val costFun = + new NormalEquationCostFun(bBar, bbBar, abBar, aaBar, aBar, fitIntercept, numFeatures) + val optimizer = l1RegFunc.map { func => + new BreezeOWLQN[Int, BDV[Double]](maxIter, 10, func, tol) + }.getOrElse(new BreezeLBFGS[BDV[Double]](maxIter, 10, tol)) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + new BDV[Double](initialCoefficientsWithIntercept)) + + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + val x = state.x.toArray.clone() + new NormalEquationSolution(x, None, Some(arrayBuilder.result())) + } + + /** + * NormalEquationCostFun implements Breeze's DiffFunction[T] for the normal equation. + * It returns the loss and gradient with L2 regularization at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ + private class NormalEquationCostFun( + bBar: Double, + bbBar: Double, + ab: DenseVector, + aa: DenseVector, + aBar: DenseVector, + fitIntercept: Boolean, + numFeatures: Int) extends DiffFunction[BDV[Double]] { + + private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coef = Vectors.fromBreeze(coefficients).toDense + if (fitIntercept) { + var j = 0 + var dotProd = 0.0 + val coefValues = coef.values + val aBarValues = aBar.values + while (j < numFeatures) { + dotProd += coefValues(j) * aBarValues(j) + j += 1 + } + coefValues(numFeatures) = bBar - dotProd + } + val aax = new DenseVector(new Array[Double](numFeaturesPlusIntercept)) + BLAS.dspmv(numFeaturesPlusIntercept, 1.0, aa, coef, 1.0, aax) + // loss = 1/2 (b^T W b - 2 x^T A^T W b + x^T A^T W A x) + val loss = 0.5 * bbBar - BLAS.dot(ab, coef) + 0.5 * BLAS.dot(coef, aax) + // gradient = A^T W A x - A^T W b + BLAS.axpy(-1.0, ab, aax) + (loss, aax.asBreeze.toDenseVector) + } + } +} + +/** + * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible + * (singular). 
+ */ +class SingularMatrixException(message: String, cause: Throwable) + extends IllegalArgumentException(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 8f5f4427e1f4..2223f126f1b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -20,19 +20,21 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ -import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.rdd.RDD /** * Model fitted by [[WeightedLeastSquares]]. + * * @param coefficients model coefficients * @param intercept model intercept * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable { + val diagInvAtWA: DenseVector, + val objectiveHistory: Array[Double]) extends Serializable { def predict(features: Vector): Double = { BLAS.dot(coefficients, features) + intercept @@ -44,35 +46,52 @@ private[ml] class WeightedLeastSquaresModel( * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares * formulation: * - * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i - * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^, + * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,, + * + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^ + * + alpha sum,,j,, abs(sigma,,j,, x,,j,,)), * - * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by - * [[standardizeLabel]] and [[standardizeFeatures]], respectively. + * where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter, + * and delta and sigma,,j,, are controlled by [[standardizeLabel]] and [[standardizeFeatures]], + * respectively. * * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to * match R's `lm`. * Turn on [[standardizeLabel]] to match R's `glmnet`. * + * @note The coefficients and intercept are always trained in the scaled space, but are returned + * on the original scale. [[standardizeFeatures]] and [[standardizeLabel]] can be used to + * control whether regularization is applied in the original space or the scaled space. * @param fitIntercept whether to fit intercept. If false, z is 0.0. - * @param regParam L2 regularization parameter (lambda) - * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the + * @param regParam Regularization parameter (lambda). + * @param elasticNetParam the ElasticNet mixing parameter (alpha). + * @param standardizeFeatures whether to standardize features. If true, sigma,,j,, is the * population standard deviation of the j-th column of A. Otherwise, * sigma,,j,, is 1.0. * @param standardizeLabel whether to standardize label. If true, delta is the population standard * deviation of the label column b. Otherwise, delta is 1.0. + * @param solverType the type of solver to use for optimization. 
+ * @param maxIter maximum number of iterations. Only for QuasiNewton solverType. + * @param tol the convergence tolerance of the iterations. Only for QuasiNewton solverType. */ private[ml] class WeightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, + val elasticNetParam: Double, val standardizeFeatures: Boolean, - val standardizeLabel: Boolean) extends Logging with Serializable { + val standardizeLabel: Boolean, + val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, + val maxIter: Int = 100, + val tol: Double = 1e-6) extends Logging with Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") if (regParam == 0.0) { logWarning("regParam is zero, which might cause numerical instability and overfitting.") } + require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, + s"elasticNetParam must be in [0, 1]: $elasticNetParam") + require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") + require(tol > 0, s"tol must be greater than zero: $tol") /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. @@ -85,73 +104,198 @@ private[ml] class WeightedLeastSquares( val triK = summary.triK val wSum = summary.wSum val bBar = summary.bBar - val bStd = summary.bStd + val bbBar = summary.bbBar val aBar = summary.aBar - val aVar = summary.aVar + val aStd = summary.aStd val abBar = summary.abBar val aaBar = summary.aaBar - val aaValues = aaBar.values - - if (bStd == 0) { - if (fitIntercept) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - val coefficients = new DenseVector(Array.ofDim(k-1)) + val numFeatures = abBar.size + val rawBStd = summary.bStd + // if b is constant (rawBStd is zero), then b cannot be scaled. In this case + // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm. + val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd + + if (rawBStd == 0) { + if (fitIntercept || bBar == 0.0) { + if (bBar == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + val coefficients = new DenseVector(Array.ofDim(numFeatures)) val intercept = bBar val diagInvAtWA = new DenseVector(Array(0D)) - return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D)) + } else { + require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + + "zero. Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. Consider setting " + + s"fitIntercept=true.") + } + } + + // scale aBar to standardized space in-place + val aBarValues = aBar.values + var j = 0 + while (j < numFeatures) { + if (aStd(j) == 0.0) { + aBarValues(j) = 0.0 } else { - require(!(regParam > 0.0 && standardizeLabel), - "The standard deviation of the label is zero. " + - "Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. 
" + - "Consider setting fitIntercept=true.") + aBarValues(j) /= aStd(j) + } + j += 1 + } + + // scale abBar to standardized space in-place + val abBarValues = abBar.values + val aStdValues = aStd.values + j = 0 + while (j < numFeatures) { + if (aStdValues(j) == 0.0) { + abBarValues(j) = 0.0 + } else { + abBarValues(j) /= (aStdValues(j) * bStd) + } + j += 1 + } + + // scale aaBar to standardized space in-place + val aaBarValues = aaBar.values + j = 0 + var p = 0 + while (j < numFeatures) { + val aStdJ = aStdValues(j) + var i = 0 + while (i <= j) { + val aStdI = aStdValues(i) + if (aStdJ == 0.0 || aStdI == 0.0) { + aaBarValues(p) = 0.0 + } else { + aaBarValues(p) /= (aStdI * aStdJ) + } + p += 1 + i += 1 } + j += 1 } - // add regularization to diagonals + val bBarStd = bBar / bStd + val bbBarStd = bbBar / (bStd * bStd) + + val effectiveRegParam = regParam / bStd + val effectiveL1RegParam = elasticNetParam * effectiveRegParam + val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam + + // add L2 regularization to diagonals var i = 0 - var j = 2 + j = 2 while (i < triK) { - var lambda = regParam - if (standardizeFeatures) { - lambda *= aVar(j - 2) + var lambda = effectiveL2RegParam + if (!standardizeFeatures) { + val std = aStd(j - 2) + if (std != 0.0) { + lambda /= (std * std) + } else { + lambda = 0.0 + } } - if (standardizeLabel && bStd != 0) { - lambda /= bStd + if (!standardizeLabel) { + lambda *= bStd } - aaValues(i) += lambda + aaBarValues(i) += lambda i += j j += 1 } + val aa = getAtA(aaBar.values, aBar.values) + val ab = getAtB(abBar.values, bBarStd) - val aa = if (fitIntercept) { - Array.concat(aaBar.values, aBar.values, Array(1.0)) + val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 && + regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) { + val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) { + Some((index: Int) => { + if (fitIntercept && index == numFeatures) { + 0.0 + } else { + if (standardizeFeatures) { + effectiveL1RegParam + } else { + if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0 + } + } + }) + } else { + None + } + new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun) } else { - aaBar.values + new CholeskySolver + } + + val solution = solver match { + case cholesky: CholeskySolver => + try { + cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar) + } catch { + // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to + // quasi-newton solver + case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => + logWarning("Cholesky solver failed due to singular covariance matrix. 
" + + "Retrying with Quasi-Newton solver.") + // ab and aa were modified in place, so reconstruct them + val _aa = getAtA(aaBar.values, aBar.values) + val _ab = getAtB(abBar.values, bBarStd) + val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None) + newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar) + } + case qn: QuasiNewtonSolver => + qn.solve(bBarStd, bbBarStd, ab, aa, aBar) } - val ab = if (fitIntercept) { - Array.concat(abBar.values, Array(bBar)) + val (coefficientArray, intercept) = if (fitIntercept) { + (solution.coefficients.slice(0, solution.coefficients.length - 1), + solution.coefficients.last * bStd) } else { - abBar.values + (solution.coefficients, 0.0) } - val x = CholeskyDecomposition.solve(aa, ab) - - val aaInv = CholeskyDecomposition.inverse(aa, k) + // convert the coefficients from the scaled space to the original space + var q = 0 + val len = coefficientArray.length + while (q < len) { + coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 } + q += 1 + } // aaInv is a packed upper triangular matrix, here we get all elements on diagonal - val diagInvAtWA = new DenseVector((1 to k).map { i => - aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) + val diagInvAtWA = solution.aaInv.map { inv => + new DenseVector((1 to k).map { i => + val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1) + inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier) + }.toArray) + }.getOrElse(new DenseVector(Array(0D))) - val (coefficients, intercept) = if (fitIntercept) { - (new DenseVector(x.slice(0, x.length - 1)), x.last) + new WeightedLeastSquaresModel(new DenseVector(coefficientArray), intercept, diagInvAtWA, + solution.objectiveHistory.getOrElse(Array(0D))) + } + + /** Construct A^T^ A from summary statistics. */ + private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = { + if (fitIntercept) { + new DenseVector(Array.concat(aaBar, aBar, Array(1.0))) } else { - (new DenseVector(x), 0.0) + new DenseVector(aaBar.clone()) } + } - new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + /** Construct A^T^ b from summary statistics. */ + private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = { + if (fitIntercept) { + new DenseVector(Array.concat(abBar, Array(bBar))) + } else { + new DenseVector(abBar.clone()) + } } } @@ -163,6 +307,13 @@ private[ml] object WeightedLeastSquares { */ val MAX_NUM_FEATURES: Int = 4096 + sealed trait Solver + case object Auto extends Solver + case object Cholesky extends Solver + case object QuasiNewton extends Solver + + val supportedSolvers = Array(Auto, Cholesky, QuasiNewton) + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -262,6 +413,11 @@ private[ml] object WeightedLeastSquares { */ def bBar: Double = bSum / wSum + /** + * Weighted mean of squared labels. + */ + def bbBar: Double = bbSum / wSum + /** * Weighted population standard deviation of labels. */ @@ -285,6 +441,24 @@ private[ml] object WeightedLeastSquares { output } + /** + * Weighted population standard deviation of features. + */ + def aStd: DenseVector = { + val std = Array.ofDim[Double](k) + var i = 0 + var j = 2 + val aaValues = aaSum.values + while (i < triK) { + val l = j - 2 + val aw = aSum(l) / wSum + std(l) = math.sqrt(aaValues(i) / wSum - aw * aw) + i += j + j += 1 + } + new DenseVector(std) + } + /** * Weighted population variance of features. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bb9e150c4977..33cb25c8c7f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -262,7 +262,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val if (familyObj == Gaussian && linkObj == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. - val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) val model = copyValues( @@ -337,7 +337,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine Instance(eta, instance.weight, instance.features) } // TODO: Make standardizeFeatures and standardizeLabel configurable. - val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) .fit(newInstances) initialModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 025ed20c75a0..519f3bdec82d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ -import org.apache.spark.ml.optim.WeightedLeastSquares +import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -177,6 +177,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * If the dimensions of features or the number of partitions are large, * this param could be adjusted to a larger size. * Default is 2. + * * @group expertSetParam */ @Since("2.1.0") @@ -194,21 +195,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + if (($(solver) == "auto" && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { - require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + - "solver is used.'") - // For low dimensional data, WeightedLeastSquares is more efficiently since the + // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), - $(standardization), true) + elasticNetParam = $(elasticNetParam), $(standardization), true, + solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) val model = optimizer.fit(instances) // When it is trained by WeightedLeastSquares, training summary does not - // attached returned model. 
+ // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) - // WeightedLeastSquares does not run through iterations. So it does not generate - // an objective history. val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), @@ -217,7 +215,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - Array(0D)) + model.objectiveHistory) return lrModel.setSummary(trainingSummary) } @@ -243,7 +241,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) if (rawYStd == 0.0) { - if ($(fitIntercept) || yMean==0.0) { + if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with // zero coefficient; as a result, training is not needed. // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 08f8f19c1e77..68771f1afbe8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -20,6 +20,8 @@ package org.apache.spark.mllib.linalg import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW +import org.apache.spark.ml.optim.SingularMatrixException + /** * Compute Cholesky decomposition. */ @@ -60,7 +62,7 @@ private[spark] object CholeskyDecomposition { case code if code < 0 => throw new IllegalStateException(s"LAPACK.$method returned $code; arg ${-code} is illegal") case code if code > 0 => - throw new IllegalArgumentException( + throw new SingularMatrixException ( s"LAPACK.$method returned $code because A is not positive definite. Is A derived from " + "a singular matrix (e.g. 
collinear column values)?") case _ => // do nothing diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index b30d995794d4..50260952ecb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -85,7 +85,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes val eta = math.log(mu / (1.0 - mu)) Instance(eta, instance.weight, instance.features) } - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(newInstances) val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) @@ -122,7 +122,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes val eta = math.log(mu) Instance(eta, instance.weight, instance.features) } - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(newInstances) val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) @@ -155,7 +155,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes var idx = 0 for (fitIntercept <- Seq(false, true)) { - val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(instances2) val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 2cb1af0dee0b..5f638b488005 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{BLAS, Vectors} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -28,6 +28,9 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext private var instances: RDD[Instance] = _ private var instancesConstLabel: RDD[Instance] = _ + private var instancesConstZeroLabel: RDD[Instance] = _ + private var collinearInstances: RDD[Instance] = _ + private var constantFeaturesInstances: RDD[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -58,26 +61,121 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) - } - test("two 
collinear features result in error with no regularization") { - val singularInstances = sc.parallelize(Seq( + /* + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + w <- c(1, 1, 1, 1) + */ + collinearInstances = sc.parallelize(Seq( Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)), Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)), Instance(4.0, 1.0, Vectors.dense(4.0, 8.0)) ), 2) - intercept[IllegalArgumentException] { - new WeightedLeastSquares( - false, regParam = 0.0, standardizeFeatures = false, - standardizeLabel = false).fit(singularInstances) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(0, 0, 0, 0) + w <- c(1, 2, 3, 4) + */ + instancesConstZeroLabel = sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + + /* + R code: + + A <- matrix(c(1, 1, 1, 1, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + constantFeaturesInstances = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(1.0, 5.0)), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(1.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(1.0, 13.0)) + ), 2) + } + + test("WLS with strong L1 regularization") { + /* + We initialize the coefficients for WLS QN solver to be weighted average of the label. Check + here that with only an intercept the model converges to bBar. + */ + val bAgg = instances.collect().foldLeft((0.0, 0.0)) { + case ((sum, weightSum), Instance(l, w, f)) => (sum + w * l, weightSum + w) } + val bBar = bAgg._1 / bAgg._2 + val wls = new WeightedLeastSquares(true, 10, 1.0, true, true) + val model = wls.fit(instances) + assert(model.intercept ~== bBar relTol 1e-6) + } - // Should not throw an exception - new WeightedLeastSquares( - false, regParam = 1.0, standardizeFeatures = false, - standardizeLabel = false).fit(singularInstances) + test("diagonal inverse of AtWA") { + /* + library(Matrix) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + w <- c(1, 2, 3, 4) + W <- Diagonal(length(w), w) + A.intercept <- cbind(A, rep.int(1, length(w))) + AtA.intercept <- t(A.intercept) %*% W %*% A.intercept + inv.intercept <- solve(AtA.intercept) + print(diag(inv.intercept)) + [1] 4.02 0.50 12.02 + + AtA <- t(A) %*% W %*% A + inv <- solve(AtA) + print(diag(inv)) + [1] 0.48336106 0.02079867 + + */ + val expectedWithIntercept = Vectors.dense(4.02, 0.50, 12.02) + val expected = Vectors.dense(0.48336106, 0.02079867) + val wlsWithIntercept = new WeightedLeastSquares(fitIntercept = true, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = WeightedLeastSquares.Cholesky) + val wlsModelWithIntercept = wlsWithIntercept.fit(instances) + val wls = new WeightedLeastSquares(false, 0.0, 0.0, true, true, + solverType = WeightedLeastSquares.Cholesky) + val wlsModel = wls.fit(instances) + + assert(expectedWithIntercept ~== wlsModelWithIntercept.diagInvAtWA relTol 1e-4) + assert(expected ~== wlsModel.diagInvAtWA relTol 1e-4) + } + + test("two collinear features") { + // Cholesky solver does not handle singular input + intercept[SingularMatrixException] { + new WeightedLeastSquares(fitIntercept = false, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky).fit(collinearInstances) + } + + // 
Cholesky should not throw an exception since regularization is applied + new WeightedLeastSquares(fitIntercept = false, regParam = 1.0, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky).fit(collinearInstances) + + // quasi-newton solvers should handle singular input and make correct predictions + // auto solver should try Cholesky first, then fall back to QN + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true); + solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.QuasiNewton)) { + val singularModel = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization, solverType = solver).fit(collinearInstances) + + collinearInstances.collect().foreach { case Instance(l, w, f) => + val pred = BLAS.dot(singularModel.coefficients, f) + singularModel.intercept + assert(pred ~== l absTol 1e-6) + } + } } test("WLS against lm") { @@ -100,13 +198,15 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { - for (standardization <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = standardization, - standardizeLabel = standardization).fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) - } + for (standardization <- Seq(false, true)) { + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = solver).fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } + } idx += 1 } } @@ -132,28 +232,256 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { for (standardization <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = standardization, - standardizeLabel = standardization).fit(instancesConstLabel) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = solver).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } } idx += 1 } + + // when label is constant zero, and fitIntercept is false, we should not train and get all zeros + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept = false, regParam = 0.0, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = solver).fit(instancesConstZeroLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual === Vectors.dense(0.0, 0.0, 0.0)) + assert(wls.objectiveHistory === Array(0.0)) + } } test("WLS with regularization when label is constant") { // if regParam 
is non-zero and standardization is true, the problem is ill-defined and // an exception is thrown. - val wls = new WeightedLeastSquares( - fitIntercept = false, regParam = 0.1, standardizeFeatures = true, - standardizeLabel = true) - intercept[IllegalArgumentException]{ - wls.fit(instancesConstLabel) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept = false, regParam = 0.1, + elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true, + solverType = solver) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } } } - test("WLS against glmnet") { + test("WLS against glmnet with constant features") { + // Cholesky solver does not handle singular input with no regularization + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, standardizeLabel = standardization, + solverType = WeightedLeastSquares.Cholesky) + intercept[SingularMatrixException] { + wls.fit(constantFeaturesInstances) + } + } + + // Cholesky also fails when regularization is added but we don't wish to standardize + val wls = new WeightedLeastSquares(true, regParam = 0.5, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false, + solverType = WeightedLeastSquares.Cholesky) + intercept[SingularMatrixException] { + wls.fit(constantFeaturesInstances) + } + + /* + for (intercept in c(FALSE, TRUE)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=0.5, + standardize=T, alpha=0.0, thresh=1E-14) + print(as.vector(coef(model))) + } + [1] 0.000000 0.000000 2.235802 + [1] 9.798771 0.000000 1.365503 + */ + // should not fail when regularization and standardization are added + val expectedCholesky = Seq( + Vectors.dense(0.0, 0.0, 2.235802), + Vectors.dense(9.798771, 0.0, 1.365503) + ) + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val wls = new WeightedLeastSquares(fitIntercept = fitIntercept, regParam = 0.5, + elasticNetParam = 0.0, standardizeFeatures = true, + standardizeLabel = true, solverType = WeightedLeastSquares.Cholesky) + .fit(constantFeaturesInstances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expectedCholesky(idx) absTol 1e-6) + idx += 1 + } + + /* + for (intercept in c(FALSE, TRUE)) { + for (standardize in c(FALSE, TRUE)) { + for (regParams in list(c(0.0, 0.0), c(0.5, 0.0), c(0.5, 0.5), c(0.5, 1.0))) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=regParams[1], + standardize=standardize, alpha=regParams[2], thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.250857 + [1] 0.000000 0.000000 2.249784 + [1] 0.000000 0.000000 2.248709 + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.235802 + [1] 0.000000 0.000000 2.238297 + [1] 0.000000 0.000000 2.240811 + [1] 8.218905 0.000000 1.517413 + [1] 8.434286 0.000000 1.496703 + [1] 8.648497 0.000000 1.476106 + [1] 8.865672 0.000000 1.455224 + [1] 8.218905 0.000000 1.517413 + [1] 9.798771 0.000000 1.365503 + [1] 9.919095 0.000000 1.353933 + [1] 10.052804 0.000000 1.341077 + */ + val expectedQuasiNewton = Seq( + Vectors.dense(0.000000, 0.000000, 2.253012), + Vectors.dense(0.000000, 0.000000, 2.250857), + Vectors.dense(0.000000, 0.000000, 2.249784), + Vectors.dense(0.000000, 0.000000, 2.248709), + Vectors.dense(0.000000, 0.000000, 2.253012), 
+ Vectors.dense(0.000000, 0.000000, 2.235802), + Vectors.dense(0.000000, 0.000000, 2.238297), + Vectors.dense(0.000000, 0.000000, 2.240811), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(8.434286, 0.000000, 1.496703), + Vectors.dense(8.648497, 0.000000, 1.476106), + Vectors.dense(8.865672, 0.000000, 1.455224), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(9.798771, 0.000000, 1.365503), + Vectors.dense(9.919095, 0.000000, 1.353933), + Vectors.dense(10.052804, 0.000000, 1.341077)) + + idx = 0 + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true); + (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) { + for (solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.Cholesky)) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.QuasiNewton) + val model = wls.fit(constantFeaturesInstances) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6) + } + idx += 1 + } + } + + test("WLS against glmnet with L1/ElasticNet regularization") { + /* + R code: + + library(glmnet) + + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.1, 0.5, 1.0)) { + for (standardize in c(FALSE, TRUE)) { + for (alpha in c(0.1, 0.5, 1.0)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda, + standardize=standardize, alpha=alpha, thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + } + [1] 0.000000 -3.292821 2.921188 + [1] 0.000000 -3.230854 2.908484 + [1] 0.000000 -3.145586 2.891014 + [1] 0.000000 -2.919246 2.841724 + [1] 0.000000 -2.938323 2.846369 + [1] 0.000000 -2.965397 2.852838 + [1] 0.000000 -2.137858 2.684464 + [1] 0.000000 -1.680094 2.590844 + [1] 0.0000000 -0.8194631 2.4151405 + [1] 0.0000000 -0.9608375 2.4301013 + [1] 0.0000000 -0.6187922 2.3634907 + [1] 0.000000 0.000000 2.240811 + [1] 0.000000 -1.346573 2.521293 + [1] 0.0000000 -0.3680456 2.3212362 + [1] 0.000000 0.000000 2.244406 + [1] 0.000000 0.000000 2.219816 + [1] 0.000000 0.000000 2.223694 + [1] 0.00000 0.00000 2.22861 + [1] 13.5631592 3.2811513 0.3725517 + [1] 13.6953934 3.3336271 0.3497454 + [1] 13.9600276 3.4600170 0.2999941 + [1] 14.2389889 3.6589920 0.2349065 + [1] 15.2374080 4.2119643 0.0325638 + [1] 15.4 4.3 0.0 + [1] 10.442365 1.246065 1.063991 + [1] 8.9580718 0.1938471 1.4090610 + [1] 8.865672 0.000000 1.455224 + [1] 13.0430927 2.4927151 0.5741805 + [1] 13.814429 2.722027 0.455915 + [1] 16.2 3.9 0.0 + [1] 9.8904768 0.7574694 1.2110177 + [1] 9.072226 0.000000 1.435363 + [1] 9.512438 0.000000 1.393035 + [1] 13.3677796 2.1721216 0.6046132 + [1] 14.2554457 2.2285185 0.5084151 + [1] 17.2 3.4 0.0 + */ + + val expected = Seq( + Vectors.dense(0, -3.2928206726474, 2.92118822588649), + Vectors.dense(0, -3.23085414359003, 2.90848366035008), + Vectors.dense(0, -3.14558628299477, 2.89101408157209), + Vectors.dense(0, -2.91924558816421, 2.84172398097327), + Vectors.dense(0, -2.93832343383477, 2.84636891947663), + Vectors.dense(0, -2.96539689593024, 2.85283836322185), + Vectors.dense(0, -2.13785756976542, 2.68446351346705), + Vectors.dense(0, -1.68009377560774, 2.59084422793154), + Vectors.dense(0, -0.819463123385533, 2.41514053108346), + Vectors.dense(0, -0.960837488151064, 2.43010130999756), + Vectors.dense(0, -0.618792151647599, 2.36349074148962), + Vectors.dense(0, 0, 2.24081114726441), + Vectors.dense(0, 
-1.34657309253953, 2.52129296638512), + Vectors.dense(0, -0.368045602821844, 2.32123616258871), + Vectors.dense(0, 0, 2.24440619621343), + Vectors.dense(0, 0, 2.21981559944924), + Vectors.dense(0, 0, 2.22369447413621), + Vectors.dense(0, 0, 2.22861024633605), + Vectors.dense(13.5631591827557, 3.28115132060568, 0.372551747695477), + Vectors.dense(13.6953934007661, 3.3336271417751, 0.349745414969587), + Vectors.dense(13.960027608754, 3.46001702257532, 0.29999407173994), + Vectors.dense(14.2389889013085, 3.65899196445023, 0.234906458633754), + Vectors.dense(15.2374079667397, 4.21196428071551, 0.0325637953681963), + Vectors.dense(15.4, 4.3, 0), + Vectors.dense(10.4423647474653, 1.24606545153166, 1.06399080283378), + Vectors.dense(8.95807177856822, 0.193847088148233, 1.4090609658784), + Vectors.dense(8.86567164179104, 0, 1.45522388059702), + Vectors.dense(13.0430927453034, 2.49271514356687, 0.574180477650271), + Vectors.dense(13.8144287399675, 2.72202744354555, 0.455915035859752), + Vectors.dense(16.2, 3.9, 0), + Vectors.dense(9.89047681835741, 0.757469417613661, 1.21101772561685), + Vectors.dense(9.07222551185964, 0, 1.43536293155196), + Vectors.dense(9.51243781094527, 0, 1.39303482587065), + Vectors.dense(13.3677796362763, 2.17212164262107, 0.604613180623227), + Vectors.dense(14.2554457236073, 2.22851848830683, 0.508415124978748), + Vectors.dense(17.2, 3.4, 0) + ) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.1, 0.5, 1.0); + standardizeFeatures <- Seq(false, true); + elasticNetParam <- Seq(0.1, 0.5, 1.0)) { + val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = elasticNetParam, + standardizeFeatures, standardizeLabel = true, solverType = WeightedLeastSquares.Auto) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("WLS against glmnet with L2 regularization") { /* R code: @@ -201,11 +529,13 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext for (fitIntercept <- Seq(false, true); regParam <- Seq(0.0, 0.1, 1.0); standardizeFeatures <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam, standardizeFeatures, standardizeLabel = true) - .fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, + standardizeFeatures, standardizeLabel = true, solverType = solver) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } idx += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 1c94ec67d79d..c0e8afbf5e34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -57,7 +57,7 @@ class LinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() val r = new Random(seed) - // When feature size is larger than 4096, normal optimizer is choosed + // When feature size is larger than 4096, normal optimizer is chosen // as the solver of linear 
regression in the case of "auto" mode. val featureSize = 4100 datasetWithSparseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -155,6 +155,42 @@ class LinearRegressionSuite assert(model.numFeatures === numFeatures) } + test("linear regression handles singular matrices") { + // check for both constant columns with intercept (zero std) and collinear + val singularDataConstantColumn = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(1.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(1.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(1.0, 13.0)) + ), 2).toDF() + + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver).setFitIntercept(true) + val model = trainer.fit(singularDataConstantColumn) + // to make it clear that WLS did not solve analytically + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + assert(model.summary.objectiveHistory !== Array(0.0)) + } + + val singularDataCollinearFeatures = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(10.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(14.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(22.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(26.0, 13.0)) + ), 2).toDF() + + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver).setFitIntercept(true) + val model = trainer.fit(singularDataCollinearFeatures) + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + assert(model.summary.objectiveHistory !== Array(0.0)) + } + } + test("linear regression with intercept without regularization") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer1 = new LinearRegression().setSolver(solver) @@ -233,12 +269,12 @@ class LinearRegressionSuite as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val coefficientsWithourInterceptR = Vectors.dense(4.70011, 7.19943) + val coefficientsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept1.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + assert(modelWithoutIntercept1.coefficients ~= coefficientsWithoutInterceptR relTol 1E-3) assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept2.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.coefficients ~= coefficientsWithoutInterceptR relTol 1E-3) } } @@ -249,55 +285,47 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) .setSolver(solver).setStandardization(false) - // Normal optimizer is not supported with only L1 regularization case. - if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", - alpha = 1.0, lambda = 0.57 )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.242284 - as.numeric.d1.V2. 4.019605 - as.numeric.d1.V3. 
6.679538 - */ - val interceptR1 = 6.242284 - val coefficientsR1 = Vectors.dense(4.019605, 6.679538) - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.416948 - as.numeric.data.V2. 3.893869 - as.numeric.data.V3. 6.724286 - */ - val interceptR2 = 6.416948 - val coefficientsR2 = Vectors.dense(3.893869, 6.724286) - - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", + alpha = 1.0, lambda = 0.57 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 + */ + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.416948 + as.numeric.data.V2. 3.893869 + as.numeric.data.V3. 6.724286 + */ + val interceptR2 = 6.416948 + val coefficientsR2 = Vectors.dense(3.893869, 6.724286) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -309,56 +337,48 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) .setFitIntercept(false).setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with only L1 regularization case. - if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, intercept=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.272927 - as.numeric.data.V3. 
4.782604 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - - assert(model1.intercept ~== interceptR1 absTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, - lambda = 0.57, intercept=FALSE, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.data.V2. 6.207817 - as.numeric.data.V3. 4.775780 - */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - - assert(model2.intercept ~== interceptR2 absTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -471,56 +491,48 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) .setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with non-zero elasticnet parameter. - if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6 )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 5.689855 - as.numeric.d1.V2. 3.661181 - as.numeric.d1.V3. 
6.000274 - */ - val interceptR1 = 5.689855 - val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - - assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 - standardize=FALSE)) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) 6.113890 - as.numeric.d1.V2. 3.407021 - as.numeric.d1.V3. 6.152512 - */ - val interceptR2 = 6.113890 - val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - - assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6 )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 + */ + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + standardize=FALSE)) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 + */ + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -532,57 +544,49 @@ class LinearRegressionSuite val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) .setFitIntercept(false).setStandardization(false).setSolver(solver) - // Normal optimizer is not supported with non-zero elasticnet parameter. - if (solver == "normal") { - intercept[IllegalArgumentException] { - trainer1.fit(datasetWithDenseFeature) - trainer2.fit(datasetWithDenseFeature) - } - } else { - val model1 = trainer1.fit(datasetWithDenseFeature) - val model2 = trainer2.fit(datasetWithDenseFeature) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6, intercept=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.d1.V2. 5.643748 - as.numeric.d1.V3. 
4.331519 - */ - val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - - assert(model1.intercept ~== interceptR1 absTol 1E-2) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) - - /* - coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, - lambda = 1.6, intercept=FALSE, standardize=FALSE )) - > coefficients - 3 x 1 sparse Matrix of class "dgCMatrix" - s0 - (Intercept) . - as.numeric.d1.V2. 5.455902 - as.numeric.d1.V3. 4.312266 - - */ - val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - - assert(model2.intercept ~== interceptR2 absTol 1E-2) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) - - model1.transform(datasetWithDenseFeature).select("features", "prediction") - .collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = - features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + - model1.intercept - assert(prediction1 ~== prediction2 relTol 1E-5) - } + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) + + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) + + /* + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE, standardize=FALSE )) + > coefficients + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) + + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) + + model1.transform(datasetWithDenseFeature).select("features", "prediction") + .collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) } } } @@ -757,7 +761,8 @@ class LinearRegressionSuite assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) - // Normal solver uses "WeightedLeastSquares". This algorithm does not generate + // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2 + // regularization is applied, this algorithm uses a direct solver and does not generate an // objective history because it does not run through iterations. 
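      // As a rough illustration of the behaviour asserted below (the model names here are
      // only for exposition, not identifiers from this suite): a fit solved analytically via
      // Cholesky reports a single placeholder entry, e.g.
      //   choleskyModel.summary.objectiveHistory == Array(0.0)
      // whereas an iterative fit (l-bfgs, or the quasi-Newton path used for L1/elastic net)
      // reports a history of loss values that should be monotonically non-increasing.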
if (solver == "l-bfgs") { // Objective function should be monotonically decreasing for linear regression @@ -776,7 +781,7 @@ class LinearRegressionSuite val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => assert(x._1 ~== x._2 absTol 1E-4) } - model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + model.summary.coefficientStandardErrors.zip(seCoefR).foreach { x => assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } @@ -950,6 +955,20 @@ class LinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + + val modelWithL1 = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .setRegParam(0.5) + .setElasticNetParam(1.0) + .fit(datasetWithWeight) + + assert(modelWithL1.summary.objectiveHistory !== Array(0.0)) + assert( + modelWithL1.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) } test("linear regression summary with weighted samples and w/o intercept by normal solver") { From 6f31833dbe0b766dfe4540a240fe92ebb7e14737 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 25 Oct 2016 15:00:33 +0800 Subject: [PATCH 007/381] [SPARK-18026][SQL] should not always lowercase partition columns of partition spec in parser ## What changes were proposed in this pull request? Currently we always lowercase the partition columns of partition spec in parser, with the assumption that table partition columns are always lowercased. However, this is not true for data source tables, which are case preserving. It's safe for now because data source tables don't store partition spec in metastore and don't support `ADD PARTITION`, `DROP PARTITION`, `RENAME PARTITION`, but we should make our code future-proof. This PR makes partition spec case preserving at parser, and improve the `PreprocessTableInsertion` analyzer rule to normalize the partition columns in partition spec, w.r.t. the table partition columns. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #15566 from cloud-fan/partition-spec. 
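The normalization described above can be illustrated with a small, self-contained Scala sketch. It is simplified: names such as `normalizeSpec` and `caseInsensitiveResolution` are for exposition only; the actual helper added by this patch is `PartitioningUtils.normalizePartitionSpec`, which takes the table name and the session resolver.

    type Resolver = (String, String) => Boolean
    val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

    // Resolve each user-supplied partition column against the table's declared partition
    // columns, rejecting unknown or duplicated columns (simplified sketch, not Spark's code).
    def normalizeSpec(
        spec: Map[String, Option[String]],
        partColNames: Seq[String],
        resolver: Resolver): Map[String, Option[String]] = {
      val normalized = spec.toSeq.map { case (key, value) =>
        val col = partColNames.find(resolver(_, key)).getOrElse(
          sys.error(s"$key is not a valid partition column"))
        col -> value
      }
      require(normalized.map(_._1).distinct.length == normalized.length,
        "Found duplicated columns in partition specification")
      normalized.toMap
    }

    // With case-insensitive resolution, PARTITION (A='10', B='p') on a table partitioned by
    // (a, b) normalizes to Map("a" -> Some("10"), "b" -> Some("p")); with a case-sensitive
    // resolver the same spec would be rejected as referencing unknown partition columns.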
--- .../sql/catalyst/parser/AstBuilder.scala | 6 ++- .../plans/logical/basicLogicalOperators.scala | 20 +-------- .../spark/sql/execution/command/ddl.scala | 34 +++++++++++++-- .../datasources/PartitioningUtils.scala | 30 +++++++++++++ .../sql/execution/datasources/rules.scala | 41 +++++++++--------- .../sql/execution/command/DDLSuite.scala | 42 +++++++++++++++++++ .../sql/hive/client/HiveClientImpl.scala | 3 ++ .../sql/hive/InsertIntoHiveTableSuite.scala | 15 +------ .../sql/hive/execution/HiveDDLSuite.scala | 5 +-- 9 files changed, 136 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 929c1c4f2d9e..38e9bb6c162a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -192,11 +192,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { override def visitPartitionSpec( ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { val parts = ctx.partitionVal.asScala.map { pVal => - val name = pVal.identifier.getText.toLowerCase + val name = pVal.identifier.getText val value = Option(pVal.constant).map(visitStringConstant) name -> value } - // Check for duplicate partition columns in one spec. + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. checkDuplicateKeys(parts, ctx) parts.toMap } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 64a787a7ae35..a48974c6322a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -356,26 +356,10 @@ case class InsertIntoTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - lazy val expectedColumns = { - if (table.output.isEmpty) { - None - } else { - // Note: The parser (visitPartitionSpec in AstBuilder) already turns - // keys in partition to their lowercase forms. 
- val staticPartCols = partition.filter(_._2.isDefined).keySet - Some(table.output.filterNot(a => staticPartCols.contains(a.name))) - } - } - assert(overwrite || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) - override lazy val resolved: Boolean = - childrenResolved && table.resolved && expectedColumns.forall { expected => - child.output.size == expected.size && child.output.zip(expected).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) - } - } + + override lazy val resolved: Boolean = childrenResolved && table.resolved } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 45fa293e5895..15656faa08e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -351,8 +351,13 @@ case class AlterTableAddPartitionCommand( "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API") } val parts = partitionSpecsAndLocs.map { case (spec, location) => + val normalizedSpec = PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) - CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location)) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] @@ -382,8 +387,21 @@ case class AlterTableRenamePartitionCommand( "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API") } DDLUtils.verifyAlterTableType(catalog, table, isView = false) + + val normalizedOldPartition = PartitioningUtils.normalizePartitionSpec( + oldPartition, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + + val normalizedNewPartition = PartitioningUtils.normalizePartitionSpec( + newPartition, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + catalog.renamePartitions( - tableName, Seq(oldPartition), Seq(newPartition)) + tableName, Seq(normalizedOldPartition), Seq(normalizedNewPartition)) Seq.empty[Row] } @@ -418,7 +436,17 @@ case class AlterTableDropPartitionCommand( throw new AnalysisException( "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") } - catalog.dropPartitions(table.identifier, specs, ignoreIfNotExists = ifExists, purge = purge) + + val normalizedSpecs = specs.map { spec => + PartitioningUtils.normalizePartitionSpec( + spec, + table.partitionColumnNames, + table.identifier.quotedString, + sparkSession.sessionState.conf.resolver) + } + + catalog.dropPartitions( + table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 81bdabb7afda..f66e8b4e2b55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -29,6 +29,7 
@@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -243,6 +244,35 @@ object PartitioningUtils { } } + /** + * Normalize the column names in partition specification, w.r.t. the real partition column names + * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a + * partition column named `month`, and it's case insensitive, we will normalize `monTh` to + * `month`. + */ + def normalizePartitionSpec[T]( + partitionSpec: Map[String, T], + partColNames: Seq[String], + tblName: String, + resolver: Resolver): Map[String, T] = { + val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) => + val normalizedKey = partColNames.find(resolver(_, key)).getOrElse { + throw new AnalysisException(s"$key is not a valid partition column in table $tblName.") + } + normalizedKey -> value + } + + if (normalizedPartSpec.map(_._1).distinct.length != normalizedPartSpec.length) { + val duplicateColumns = normalizedPartSpec.map(_._1).groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicated columns in partition specification: " + + duplicateColumns.mkString(", ")) + } + + normalizedPartSpec.toMap + } + /** * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- * casting order is: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index bd6eb6e0535a..cf501cdc919e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -187,8 +187,8 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl colName: String, colType: String): String = { val tableCols = schema.map(_.name) - val conf = sparkSession.sessionState.conf - tableCols.find(conf.resolver(_, colName)).getOrElse { + val resolver = sparkSession.sessionState.conf.resolver + tableCols.find(resolver(_, colName)).getOrElse { failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " + s"defined table columns are: ${tableCols.mkString(", ")}") } @@ -209,42 +209,41 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { tblName: String, partColNames: Seq[String]): InsertIntoTable = { - val expectedColumns = insert.expectedColumns - if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) { + val normalizedPartSpec = PartitioningUtils.normalizePartitionSpec( + insert.partition, partColNames, tblName, conf.resolver) + + val expectedColumns = { + val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet + insert.table.output.filterNot(a => staticPartCols.contains(a.name)) + } + + if (expectedColumns.length != insert.child.schema.length) { throw new AnalysisException( s"Cannot insert into table $tblName because the number of columns are different: " + - s"need ${expectedColumns.get.length} columns, " + + s"need ${expectedColumns.length} columns, " + s"but query has ${insert.child.schema.length} columns.") } - if (insert.partition.nonEmpty) { - // the query's partitioning must match the table's partitioning - // this is set for queries 
like: insert into ... partition (one = "a", two = ) - val samePartitionColumns = - if (conf.caseSensitiveAnalysis) { - insert.partition.keySet == partColNames.toSet - } else { - insert.partition.keySet.map(_.toLowerCase) == partColNames.map(_.toLowerCase).toSet - } - if (!samePartitionColumns) { + if (normalizedPartSpec.nonEmpty) { + if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( s""" |Requested partitioning does not match the table $tblName: - |Requested partitions: ${insert.partition.keys.mkString(",")} + |Requested partitions: ${normalizedPartSpec.keys.mkString(",")} |Table partitions: ${partColNames.mkString(",")} """.stripMargin) } - expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + + castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) } else { - // All partition columns are dynamic because because the InsertIntoTable command does + // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + castAndRenameChildOutput(insert, expectedColumns) .copy(partition = partColNames.map(_ -> None).toMap) } } - // TODO: do we really need to rename? - def castAndRenameChildOutput( + private def castAndRenameChildOutput( insert: InsertIntoTable, expectedOutput: Seq[Attribute]): InsertIntoTable = { val newChildOutput = expectedOutput.zip(insert.child.output).map { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index d593bfb4ce19..de326f80f659 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -926,23 +926,33 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createPartitionedTable(tableIdent, isDatasourceTable = false) + + // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + // rename without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + // table to alter does not exist intercept[NoSuchTableException] { sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") } + // partition to rename does not exist intercept[NoSuchPartitionException] { sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") } + + // partition spec in RENAME PARTITION should be case insensitive by default + sql("ALTER TABLE tab1 PARTITION (A='10', B='p') RENAME TO PARTITION (A='1', B='p')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } test("alter table: rename partition (datasource 
table)") { @@ -1334,6 +1344,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val part2 = Map("a" -> "2", "b" -> "6") val part3 = Map("a" -> "3", "b" -> "7") val part4 = Map("a" -> "4", "b" -> "8") + val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, part1, tableIdent) @@ -1341,6 +1352,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + + // basic add partition maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") @@ -1351,6 +1364,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) } + // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { @@ -1360,14 +1374,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4)) } + // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (a='4', b='9')") } + // partition to add already exists intercept[AnalysisException] { sql("ALTER TABLE tab1 ADD PARTITION (a='4', b='8')") } + + // partition to add already exists when using IF NOT EXISTS maybeWrapException(isDatasourceTable) { sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") } @@ -1375,6 +1393,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4)) } + + // partition spec in ADD PARTITION should be case insensitive by default + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD PARTITION (A='9', B='9')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4, part5)) + } } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { @@ -1395,12 +1422,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } + + // basic drop partition maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) } + // drop partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { @@ -1409,20 +1439,32 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) } + // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (a='2')") } + // partition to drop does not exist intercept[AnalysisException] { sql("ALTER TABLE tab1 DROP PARTITION (a='300')") } + + // partition to drop does not exist when using IF EXISTS 
maybeWrapException(isDatasourceTable) { sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) } + + // partition spec in DROP PARTITION should be case insensitive by default + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP PARTITION (A='1', B='5')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).isEmpty) + } } test("drop build-in function") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index e745a8c5b358..8835b266b22a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -831,6 +831,9 @@ private[hive] class HiveClientImpl( new HivePartition(ht, tpart) } + // TODO (cloud-fan): the column names in partition specification are always lower cased because + // Hive metastore is not case preserving. We should normalize them to the actual column names of + // the table, once we store partition spec of data source tables. private def fromHivePartition(hp: HivePartition): CatalogTablePartition = { val apiPartition = hp.getTPartition CatalogTablePartition( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d9ce1c3dc18f..e3ddaf725424 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -370,17 +370,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef assert(cause.getMessage.contains("insertInto() can't be used together with partitionBy().")) } - test("InsertIntoTable#resolved should include dynamic partitions") { - withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") - - val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, - Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) - assert(!logical.resolved, "Should not resolve: missing partition data") - } - } - testPartitionedTable( "SPARK-16036: better error message when insert into a table with mismatch schema") { tableName => @@ -409,8 +398,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12") - // c is defined twice. Parser will complain. - intercept[ParseException] { + // c is defined twice. Analyzer will complain. 
+ intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3d1712e4354c..e9268a922cf5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -200,9 +200,8 @@ class HiveDDLSuite val message = intercept[AnalysisException] { sql(s"ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-09', unknownCol='12')") } - assert(message.getMessage.contains( - "Partition spec is invalid. The spec (ds, unknowncol) must be contained within the " + - "partition spec (ds, hr) defined in table '`default`.`exttable_with_partitions`'")) + assert(message.getMessage.contains("unknownCol is not a valid partition column in table " + + "`default`.`exttable_with_partitions`")) sql( s""" From 38cdd6ccdaba7f8da985c4f4efe5bd93a46a2b53 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 25 Oct 2016 03:19:50 -0700 Subject: [PATCH 008/381] [SPARK-14634][ML][FOLLOWUP] Delete superfluous line in BisectingKMeans ## What changes were proposed in this pull request? As commented by jkbradley in https://github.com/apache/spark/pull/12394, `model.setSummary(summary)` is superfluous ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15619 from zhengruifeng/del_superfluous. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 5 ++--- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index add8ee2a4ff8..ef2d918ea354 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -265,9 +265,8 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(summary) - val m = model.setSummary(summary) - instr.logSuccess(m) - m + instr.logSuccess(model) + model } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index b04e82838e71..0d2405b50068 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -324,9 +324,9 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - val m = model.setSummary(summary) - instr.logSuccess(m) - m + model.setSummary(summary) + instr.logSuccess(model) + model } @Since("1.5.0") From ac8ff920faec6ee06e17212e2b5d2ee117495e87 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 25 Oct 2016 10:22:02 -0700 Subject: [PATCH 009/381] [SPARK-17748][FOLLOW-UP][ML] Fix build error for Scala 2.10. ## What changes were proposed in this pull request? #15394 introduced build error for Scala 2.10, this PR fix it. ## How was this patch tested? Existing test. Author: Yanbo Liang Closes #15625 from yanboliang/spark-17748-scala. 
--- .../spark/ml/optim/WeightedLeastSquaresSuite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 5f638b488005..3cdab0327991 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -280,7 +280,7 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext } // Cholesky also fails when regularization is added but we don't wish to standardize - val wls = new WeightedLeastSquares(true, regParam = 0.5, elasticNetParam = 0.0, + val wls = new WeightedLeastSquares(fitIntercept = true, regParam = 0.5, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false, solverType = WeightedLeastSquares.Cholesky) intercept[SingularMatrixException] { @@ -470,10 +470,11 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true); regParam <- Seq(0.1, 0.5, 1.0); - standardizeFeatures <- Seq(false, true); + standardization <- Seq(false, true); elasticNetParam <- Seq(0.1, 0.5, 1.0)) { - val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = elasticNetParam, - standardizeFeatures, standardizeLabel = true, solverType = WeightedLeastSquares.Auto) + val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.Auto) .fit(instances) val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) assert(actual ~== expected(idx) absTol 1e-4) @@ -528,10 +529,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true); regParam <- Seq(0.0, 0.1, 1.0); - standardizeFeatures <- Seq(false, true)) { + standardization <- Seq(false, true)) { for (solver <- WeightedLeastSquares.supportedSolvers) { val wls = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures, standardizeLabel = true, solverType = solver) + standardizeFeatures = standardization, standardizeLabel = true, solverType = solver) .fit(instances) val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) assert(actual ~== expected(idx) absTol 1e-4) From c5fe3dd4f59c464c830b414acccd3cca0fdd877c Mon Sep 17 00:00:00 2001 From: Vinayak Date: Tue, 25 Oct 2016 10:36:03 -0700 Subject: [PATCH 010/381] [SPARK-18010][CORE] Reduce work performed for building up the application list for the History Server app list UI page ## What changes were proposed in this pull request? allow ReplayListenerBus to skip deserialising and replaying certain events using an inexpensive check of the event log entry. Use this to ensure that when event log replay is triggered for building the application list, we get the ReplayListenerBus to skip over all but the few events needed for our immediate purpose. Refer [SPARK-18010] for the motivation behind this change. ## How was this patch tested? Tested with existing HistoryServer and ReplayListener unit test suites. All tests pass. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: Vinayak Closes #15556 from vijoshi/SAAS-467_master. 
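For illustration, the core of the change is that an event-log line can be classified with a cheap `startsWith` check on its JSON prefix before paying for full deserialization. The following is a minimal, standalone sketch of that filtering idea only; the object and method names are illustrative, and the real implementation is the `FsHistoryProvider`/`ReplayListenerBus` diff below.

```scala
// Sketch: select only application start/end events by string prefix, so every other
// event line is skipped without ever being JSON-parsed.
object EventFilterSketch {
  type ReplayEventsFilter = String => Boolean

  val appStartPrefix = "{\"Event\":\"SparkListenerApplicationStart\""
  val appEndPrefix = "{\"Event\":\"SparkListenerApplicationEnd\""

  val appEventsOnly: ReplayEventsFilter =
    line => line.startsWith(appStartPrefix) || line.startsWith(appEndPrefix)

  def main(args: Array[String]): Unit = {
    val lines = Seq(
      """{"Event":"SparkListenerApplicationStart","App Name":"demo"}""",
      """{"Event":"SparkListenerTaskEnd","Stage ID":0}""",
      """{"Event":"SparkListenerApplicationEnd","Timestamp":1}""")
    // Only the first and last lines survive; the TaskEnd line is never deserialized.
    lines.filter(appEventsOnly).foreach(println)
  }
}
```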
--- .../deploy/history/FsHistoryProvider.scala | 120 ++++++++++-------- .../spark/scheduler/ReplayListenerBus.scala | 39 +++++- 2 files changed, 101 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 530cc5252214..dfc1aad64c81 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -78,10 +79,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) import FsHistoryProvider._ - private val NOT_STARTED = "" - - private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" - // Interval between safemode checks. private val SAFEMODE_CHECK_INTERVAL_S = conf.getTimeAsSeconds( "spark.history.fs.safemodeCheck.interval", "5s") @@ -241,11 +238,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() - replayBus.addListener(appListener) - val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), - replayBus) - appAttemptInfo.map { info => + + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + + val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + + if (appListener.appId.isDefined) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up @@ -254,8 +252,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.viewAcls.getOrElse("")) ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize)) + Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) + } else { + None } + } } } catch { @@ -411,28 +412,54 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { - val bus = new ReplayListenerBus() - val res = replay(fileStatus, bus) - res match { - case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully: $r") - case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. 
" + - "The application may have not started.") - } - res - } catch { - case e: Exception => - logError( - s"Exception encountered when attempting to load application log ${fileStatus.getPath}", - e) - None + val eventsFilter: ReplayEventsFilter = { eventString => + eventString.startsWith(APPL_START_EVENT_PREFIX) || + eventString.startsWith(APPL_END_EVENT_PREFIX) + } + + val logPath = fileStatus.getPath() + + val appCompleted = isApplicationCompleted(fileStatus) + + val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. + if (appListener.appId.isDefined) { + val attemptInfo = new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + fileStatus.getModificationTime(), + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted, + fileStatus.getLen() + ) + fileToAppInfo(logPath) = attemptInfo + logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") + Some(attemptInfo) + } else { + logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + None } + } catch { + case e: Exception => + logError( + s"Exception encountered when attempting to load application log ${fileStatus.getPath}", + e) + None + } + if (newAttempts.isEmpty) { return } @@ -564,12 +591,16 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Replays the events in the specified log file and returns information about the associated - * application. Return `None` if the application ID cannot be located. + * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns + * an `ApplicationEventListener` instance with event data captured from the replay. + * `ReplayEventsFilter` determines what events are replayed and can therefore limit the + * data captured in the returned `ApplicationEventListener` instance. */ private def replay( eventLog: FileStatus, - bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { + appCompleted: Boolean, + bus: ReplayListenerBus, + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): ApplicationEventListener = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, @@ -581,30 +612,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener - val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString, !appCompleted) - - // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. 
- if (appListener.appId.isDefined) { - val attemptInfo = new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - eventLog.getModificationTime(), - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted, - eventLog.getLen() - ) - fileToAppInfo(logPath) = attemptInfo - Some(attemptInfo) - } else { - None - } + bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) + appListener } finally { logInput.close() } @@ -689,6 +699,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + private val NOT_STARTED = "" + + private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" + + private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" + + private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index d32f5eb7bfe9..3eff8d952bfd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -25,6 +25,7 @@ import com.fasterxml.jackson.core.JsonParseException import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.util.JsonProtocol /** @@ -43,30 +44,45 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * @param sourceName Filename (or other source identifier) from whence @logData is being read * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not + * @param eventsFilter Filter function to select JSON event strings in the log data stream that + * should be parsed and replayed. When not specified, all event strings in the log data + * are parsed and replayed. 
*/ def replay( logData: InputStream, sourceName: String, - maybeTruncated: Boolean = false): Unit = { + maybeTruncated: Boolean = false, + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + var currentLine: String = null - var lineNumber: Int = 1 + var lineNumber: Int = 0 + try { - val lines = Source.fromInputStream(logData).getLines() - while (lines.hasNext) { - currentLine = lines.next() + val lineEntries = Source.fromInputStream(logData) + .getLines() + .zipWithIndex + .filter { case (line, _) => eventsFilter(line) } + + while (lineEntries.hasNext) { try { + val entry = lineEntries.next() + + currentLine = entry._1 + lineNumber = entry._2 + 1 + postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) } catch { case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated - if (!maybeTruncated || lines.hasNext) { + // the last entry may not be the very last line in the event log, but we treat it + // as such in a best effort to replay the given input + if (!maybeTruncated || lineEntries.hasNext) { throw jpe } else { logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } - lineNumber += 1 } } catch { case ioe: IOException => @@ -78,3 +94,12 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + + +private[spark] object ReplayListenerBus { + + type ReplayEventsFilter = (String) => Boolean + + // utility filter that selects all event logs during replay + val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true } +} From a21791e3164f4e6546fbe0a90017a4394a05deb1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 25 Oct 2016 12:08:17 -0700 Subject: [PATCH 011/381] [SPARK-18070][SQL] binary operator should not consider nullability when comparing input types ## What changes were proposed in this pull request? Binary operator requires its inputs to be of same type, but it should not consider nullability, e.g. `EqualTo` should be able to compare an element-nullable array and an element-non-nullable array. ## How was this patch tested? a regression test in `DataFrameSuite` Author: Wenchen Fan Closes #15606 from cloud-fan/type-bug. --- .../spark/sql/catalyst/expressions/Expression.scala | 2 +- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index fa1a2ad56ccb..9edc1ceff26a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -511,7 +511,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. 
- if (left.dataType != right.dataType) { + if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3fb7eeefba67..33b3b78c9f04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1649,4 +1649,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { dates.except(widenTypedRows).collect() dates.intersect(widenTypedRows).collect() } + + test("SPARK-18070 binary operator should not consider nullability when comparing input types") { + val rows = Seq(Row(Seq(1), Seq(1))) + val schema = new StructType() + .add("array1", ArrayType(IntegerType)) + .add("array2", ArrayType(IntegerType, containsNull = false)) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + assert(df.filter($"array1" === $"array2").count() == 1) + } } From 2c7394ad096201cd721be7f532da9d97028cc577 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 25 Oct 2016 13:11:21 -0700 Subject: [PATCH 012/381] [SPARK-18019][ML] Add instrumentation to GBTs ## What changes were proposed in this pull request? Add instrumentation for logging in ML GBT, part of umbrella ticket [SPARK-14567](https://issues.apache.org/jira/browse/SPARK-14567) ## How was this patch tested? Tested locally: ```` 16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: training: numPartitions=1 storageLevel=StorageLevel(1 replicas) 16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"maxIter":1} 16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"numFeatures":2} 16/10/20 10:24:51 INFO Instrumentation: GBTRegressor-gbtr_2b460d3e2e93-1207021668-45: {"numClasses":0} ... 16/10/20 15:54:21 INFO Instrumentation: GBTRegressor-gbtr_065fad465377-1922077832-22: training finished ```` Author: sethah Closes #15574 from sethah/gbt_instr. 
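The `Instrumentation` helper used in the diff below is internal to Spark ML, so the sketch here only illustrates the shape of the pattern with a plain JDK logger: record the parameters and dataset characteristics when training starts, then emit a success marker once the model has been produced. All names in the sketch are illustrative, not Spark's API.

```scala
// Schematic sketch of fit-time instrumentation (not Spark's internal Instrumentation class).
import java.util.logging.Logger

final class TrainingInstrumentation(estimator: String, runId: Long) {
  private val log = Logger.getLogger(getClass.getName)
  private val prefix = s"$estimator-$runId"

  def logParams(params: (String, Any)*): Unit =
    log.info(s"$prefix: " + params.map { case (k, v) => s"\"$k\":$v" }.mkString("{", ",", "}"))
  def logNumFeatures(n: Int): Unit = logParams("numFeatures" -> n)
  def logNumClasses(n: Int): Unit = logParams("numClasses" -> n)
  def logSuccess(): Unit = log.info(s"$prefix: training finished")
}

object TrainingInstrumentationExample {
  def main(args: Array[String]): Unit = {
    val instr = new TrainingInstrumentation("GBTClassifier", 42L)
    instr.logParams("maxIter" -> 1)
    instr.logNumFeatures(2)
    instr.logNumClasses(2)
    // ... gradient boosting would run here ...
    instr.logSuccess()
  }
}
```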
--- .../apache/spark/ml/classification/GBTClassifier.scala | 10 +++++++++- .../org/apache/spark/ml/regression/GBTRegressor.scala | 9 ++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ba70293273f9..8bffe0cda032 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -137,9 +137,17 @@ class GBTClassifier @Since("1.4.0") ( } val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + instr.logNumClasses(2) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bb01f9d5a364..fa69d60836e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -123,9 +123,16 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + + val instr = Instrumentation.create(this, oldDataset) + instr.logParams(params: _*) + instr.logNumFeatures(numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) + instr.logSuccess(m) + m } @Since("1.4.0") From c329a568b58d65c492a43926bf0f588f2ae6a66e Mon Sep 17 00:00:00 2001 From: hayashidac Date: Wed, 26 Oct 2016 07:13:48 +0900 Subject: [PATCH 013/381] [SPARK-16988][SPARK SHELL] spark history server log needs to be fixed to show https url when ssl is enabled spark history server log needs to be fixed to show https url when ssl is enabled Author: chie8842 Closes #15611 from hayashidac/SPARK-16988. --- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 5 ++++- .../test/scala/org/apache/spark/SSLOptionsSuite.scala | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 4118fcf46b42..a05e0efb7a3e 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -147,7 +147,10 @@ private[spark] abstract class WebUI( } /** Return the url of web interface. Only valid after bind(). */ - def webUrl: String = s"http://$publicHostName:$boundPort" + def webUrl: String = { + val protocol = if (sslOptions.enabled) "https" else "http" + s"$protocol://$publicHostName:$boundPort" + } /** Return the actual port to which this server is bound. 
Only valid after bind(). */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 159b448e05b0..2b8b1805bc83 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -79,7 +79,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -102,20 +102,20 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConf conf.set("spark.ssl.enabled", "true") - conf.set("spark.ui.ssl.enabled", "false") + conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.ui.keyStorePassword", "12345") conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.trustStore.isDefined === true) From 12b3e8d2e02788c3bebfecdd69755e94d80011c9 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 25 Oct 2016 21:42:59 -0700 Subject: [PATCH 014/381] [SPARK-18007][SPARKR][ML] update SparkR MLP - add initalWeights parameter ## What changes were proposed in this pull request? update SparkR MLP, add initalWeights parameter. ## How was this patch tested? test added. Author: WeichenXu Closes #15552 from WeichenXu123/mlp_r_add_initialWeight_param. --- R/pkg/R/mllib.R | 14 ++++++++++---- R/pkg/inst/tests/testthat/test_mllib.R | 15 +++++++++++++++ .../r/MultilayerPerceptronClassifierWrapper.scala | 9 ++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b901307f8f40..bf182be8e23d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -665,6 +665,8 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param tol convergence tolerance of iterations. #' @param stepSize stepSize parameter. #' @param seed seed parameter for weights initialization. +#' @param initialWeights initialWeights parameter for weights initialization, it should be a +#' numeric vector. #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
#' @rdname spark.mlp @@ -677,8 +679,9 @@ setMethod("predict", signature(object = "KMeansModel"), #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") #' #' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", -#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) +#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", +#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, +#' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) #' #' # get the summary of the model #' summary(model) @@ -695,7 +698,7 @@ setMethod("predict", signature(object = "KMeansModel"), #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame"), function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL) { + tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") } @@ -706,10 +709,13 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"), if (!is.null(seed)) { seed <- as.character(as.integer(seed)) } + if (!is.null(initialWeights)) { + initialWeights <- as.array(as.numeric(na.omit(initialWeights))) + } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed) + as.numeric(stepSize), seed, initialWeights) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index c99315726a22..33cc069f1445 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -410,6 +410,21 @@ test_that("spark.mlp", { model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) + + # test initialWeights + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + + model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) }) test_that("spark.naiveBayes", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 10673003534e..2193eb80e9fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -24,6 +24,7 @@ import org.json4s.jackson.JsonMethods._ 
import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter} import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,8 @@ private[r] object MultilayerPerceptronClassifierWrapper maxIter: Int, tol: Double, stepSize: Double, - seed: String + seed: String, + initialWeights: Array[Double] ): MultilayerPerceptronClassifierWrapper = { // get labels and feature names from output schema val schema = data.schema @@ -73,6 +75,11 @@ private[r] object MultilayerPerceptronClassifierWrapper .setStepSize(stepSize) .setPredictionCol(PREDICTED_LABEL_COL) if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt) + if (initialWeights != null) { + require(initialWeights.length > 0) + mlp.setInitialWeights(Vectors.dense(initialWeights)) + } + val pipeline = new Pipeline() .setStages(Array(mlp)) .fit(data) From 93b8ad184aa3634f340d43a8bdf99836ef3d4f6c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 26 Oct 2016 00:38:34 -0700 Subject: [PATCH 015/381] [SPARK-17693][SQL] Fixed Insert Failure To Data Source Tables when the Schema has the Comment Field ### What changes were proposed in this pull request? ```SQL CREATE TABLE tab1(col1 int COMMENT 'a', col2 int) USING parquet INSERT INTO TABLE tab1 SELECT 1, 2 ``` The insert attempt will fail if the target table has a column with comments. The error is strange to the external users: ``` assertion failed: No plan for InsertIntoTable Relation[col1#15,col2#16] parquet, false, false +- Project [1 AS col1#19, 2 AS col2#20] +- OneRowRelation$ ``` This PR is to fix the above bug by checking the metadata when comparing the schema between the table and the query. If not matched, we also copy the metadata. This is an alternative to https://github.com/apache/spark/pull/15266 ### How was this patch tested? Added a test case Author: gatorsmile Closes #15615 from gatorsmile/insertDataSourceTableWithCommentSolution2. 
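The underlying reason the old check failed is that a column comment lives in the `StructField` metadata rather than in the name or data type, so the table schema and the query schema can disagree even when every name and type matches. A small standalone illustration of that point, e.g. in spark-shell (the `comment` key is how this sketch models it, not necessarily the exact key used internally):

```scala
// Sketch: two fields with identical name and type but different metadata are not equal,
// which is why the rule must compare metadata and copy it onto the casted output column.
import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructField}

val commented = StructField("col1", IntegerType,
  metadata = new MetadataBuilder().putString("comment", "a").build())
val plain = StructField("col1", IntegerType, metadata = Metadata.empty)

assert(commented.name == plain.name && commented.dataType == plain.dataType)
assert(commented.metadata != plain.metadata) // the comment alone makes the fields differ
```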
--- .../sql/execution/datasources/rules.scala | 10 ++++- .../spark/sql/sources/InsertSuite.scala | 42 +++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index cf501cdc919e..4647b11af4df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -248,10 +248,16 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { expectedOutput: Seq[Attribute]): InsertIntoTable = { val newChildOutput = expectedOutput.zip(insert.child.output).map { case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) { + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { actual } else { - Alias(Cast(actual, expected.dataType), expected.name)() + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias(Cast(actual, expected.dataType), expected.name)( + explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5eb54643f204..4a85b5975ea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -185,6 +185,48 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } + test("INSERT INTO TABLE with Comment in columns") { + val tabName = "tab1" + withTable(tabName) { + sql( + s""" + |CREATE TABLE $tabName(col1 int COMMENT 'a', col2 int) + |USING parquet + """.stripMargin) + sql(s"INSERT INTO TABLE $tabName SELECT 1, 2") + + checkAnswer( + sql(s"SELECT col1, col2 FROM $tabName"), + Row(1, 2) :: Nil + ) + } + } + + test("INSERT INTO TABLE - complex type but different names") { + val tab1 = "tab1" + val tab2 = "tab2" + withTable(tab1, tab2) { + sql( + s""" + |CREATE TABLE $tab1 (s struct) + |USING parquet + """.stripMargin) + sql(s"INSERT INTO TABLE $tab1 SELECT named_struct('col1','1','col2','2')") + + sql( + s""" + |CREATE TABLE $tab2 (p struct) + |USING parquet + """.stripMargin) + sql(s"INSERT INTO TABLE $tab2 SELECT * FROM $tab1") + + checkAnswer( + spark.table(tab1), + spark.table(tab2) + ) + } + } + test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { sql( From 6c7d094ec4d45a05c1ec8a418e507e45f5a88b7d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 26 Oct 2016 14:19:40 +0200 Subject: [PATCH 016/381] [SPARK-18022][SQL] java.lang.NullPointerException instead of real exception when saving DF to MySQL ## What changes were proposed in this pull request? On null next exception in JDBC, don't init it as cause or suppressed ## How was this patch tested? Existing tests Author: Sean Owen Closes #15599 from srowen/SPARK-18022. 
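One way the NullPointerException surfaces: `SQLException.getNextException` legitimately returns null when the driver chained nothing, and passing that null on to `Throwable.addSuppressed` throws NPE by contract, so the user sees the NPE instead of the real `SQLException`. A minimal sketch of the failure mode and of the guard added here (the exception objects are made up for illustration):

```scala
// Sketch: chaining a null "next exception" masks the real error with an NPE.
import java.sql.SQLException

val e = new SQLException("batch insert failed", new RuntimeException("root cause"))
val cause = e.getNextException // null: the driver chained no further exception

try {
  // old behaviour: chain unconditionally
  if (e.getCause == null) e.initCause(cause) else e.addSuppressed(cause)
} catch {
  case npe: NullPointerException => println(s"real SQLException hidden behind: $npe")
}

// fixed behaviour: only chain when there is actually something to chain
if (cause != null && e.getCause != cause) e.addSuppressed(cause)
```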
--- .../apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e32db73bd6c6..41edb6511c2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -607,7 +607,7 @@ object JdbcUtils extends Logging { } catch { case e: SQLException => val cause = e.getNextException - if (e.getCause != cause) { + if (cause != null && e.getCause != cause) { if (e.getCause == null) { e.initCause(cause) } else { From 297813647508480d7b4b5bccd02b93b8b914301f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 26 Oct 2016 14:23:11 +0200 Subject: [PATCH 017/381] [SPARK-18027][YARN] .sparkStaging not clean on RM ApplicationNotFoundException ## What changes were proposed in this pull request? Cleanup YARN staging dir on all `KILLED`/`FAILED` paths in `monitorApplication` ## How was this patch tested? Existing tests Author: Sean Owen Closes #15598 from srowen/SPARK-18027. --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 6e4f68c74c36..55e4a833b670 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1059,9 +1059,11 @@ private[spark] class Client( } catch { case e: ApplicationNotFoundException => logError(s"Application $appId not found.") + cleanupStagingDir(appId) return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) case NonFatal(e) => logError(s"Failed to contact YARN for application $appId.", e) + // Don't necessarily clean up staging dir because status is unknown return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState From 5d0f81da49e86ee93ecf679a20d024ea2cb8b3d3 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 26 Oct 2016 14:26:54 +0200 Subject: [PATCH 018/381] [SPARK-4411][WEB UI] Add "kill" link for jobs in the UI ## What changes were proposed in this pull request? Currently users can kill stages via the web ui but not jobs directly (jobs are killed if one of their stages is). I've added the ability to kill jobs via the web ui. This code change is based on #4823 by lianhuiwang and updated to work with the latest code matching how stages are currently killed. In general I've copied the kill stage code warning and note comments and all. I also updated applicable tests and documentation. ## How was this patch tested? Manually tested and dev/run-tests ![screen shot 2016-10-11 at 4 49 43 pm](https://cloud.githubusercontent.com/assets/13952758/19292857/12f1b7c0-8fd4-11e6-8982-210249f7b697.png) Author: Alex Bozarth Author: Lianhui Wang Closes #15441 from ajbozarth/spark4411. 
--- .../scala/org/apache/spark/ui/SparkUI.scala | 11 +++-- .../apache/spark/ui/jobs/AllJobsPage.scala | 34 ++++++++++++-- .../org/apache/spark/ui/jobs/JobsTab.scala | 17 +++++++ .../org/apache/spark/ui/jobs/StageTable.scala | 5 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 17 +++---- .../org/apache/spark/ui/UISeleniumSuite.scala | 47 +++++++++++++++---- docs/configuration.md | 2 +- 7 files changed, 104 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index ef71db89798f..f631a047a707 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -58,14 +58,13 @@ private[spark] class SparkUI private ( val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) - - val stagesTab = new StagesTab(this) - var appId: String = _ /** Initialize all components of the server. */ def initialize() { - attachTab(new JobsTab(this)) + val jobsTab = new JobsTab(this) + attachTab(jobsTab) + val stagesTab = new StagesTab(this) attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) @@ -73,7 +72,9 @@ private[spark] class SparkUI private ( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) - // This should be POST only, but, the YARN AM proxy won't proxy POSTs + // These should be POST only, but, the YARN AM proxy won't proxy POSTs + attachHandler(createRedirectHandler( + "/jobs/job/kill", "/jobs/", jobsTab.handleKillRequest, httpMethods = Set("GET", "POST"))) attachHandler(createRedirectHandler( "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f6713097b934..173fc3cf31ce 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -218,7 +218,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { request: HttpServletRequest, tableHeaderId: String, jobTag: String, - jobs: Seq[JobUIData]): Seq[Node] = { + jobs: Seq[JobUIData], + killEnabled: Boolean): Seq[Node] = { val allParameters = request.getParameterMap.asScala.toMap val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) @@ -264,6 +265,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { parameterOtherTable, parent.jobProgresslistener.stageIdToInfo, parent.jobProgresslistener.stageIdToData, + killEnabled, currentTime, jobIdTitle, pageSize = jobPageSize, @@ -290,9 +292,12 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val completedJobs = listener.completedJobs.reverse.toSeq val failedJobs = listener.failedJobs.reverse.toSeq - val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs) - val completedJobsTable = jobsTable(request, "completed", "completedJob", completedJobs) - val failedJobsTable = jobsTable(request, "failed", "failedJob", failedJobs) + val activeJobsTable = + jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) + val completedJobsTable = + jobsTable(request, "completed", "completedJob", completedJobs, 
killEnabled = false) + val failedJobsTable = + jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty @@ -483,6 +488,7 @@ private[ui] class JobPagedTable( parameterOtherTable: Iterable[String], stageIdToInfo: HashMap[Int, StageInfo], stageIdToData: HashMap[(Int, Int), StageUIData], + killEnabled: Boolean, currentTime: Long, jobIdTitle: String, pageSize: Int, @@ -586,12 +592,30 @@ private[ui] class JobPagedTable( override def row(jobTableRow: JobTableRowData): Seq[Node] = { val job = jobTableRow.jobData + val killLink = if (killEnabled) { + val confirm = + s"if (window.confirm('Are you sure you want to kill job ${job.jobId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" + // SPARK-6846 this should be POST-only but YARN AM won't proxy POST + /* + val killLinkUri = s"$basePathUri/jobs/job/kill/" +
<form action={killLinkUri} method="POST" style="display:inline"> + <input type="hidden" name="id" value={job.jobId.toString}/> + <a href="#" onclick={confirm} class="kill-link">(kill)</a> + </form> +
+ */ + val killLinkUri = s"$basePath/jobs/job/kill/?id=${job.jobId}" + (kill) + } else { + Seq.empty + } + {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} - {jobTableRow.jobDescription} + {jobTableRow.jobDescription} {killLink} {jobTableRow.lastStageName} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 7b00b558d591..620c54c2dc0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import javax.servlet.http.HttpServletRequest + import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab} @@ -35,4 +37,19 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { attachPage(new AllJobsPage(this)) attachPage(new JobPage(this)) + + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { + val jobId = Option(request.getParameter("id")).map(_.toInt) + jobId.foreach { id => + if (jobProgresslistener.activeJobs.contains(id)) { + sc.foreach(_.cancelJob(id)) + // Do a quick pause here to give Spark time to kill the job so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9b9b4681ba5d..c9d0431e2d2f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -353,12 +353,13 @@ private[ui] class StagePagedTable( val killLinkUri = s"$basePathUri/stages/stage/kill/"
- (kill)
*/ - val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}&terminate=true" + val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}" (kill) + } else { + Seq.empty } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 573192ac17d4..c1f25114371f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -39,15 +39,16 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt - if (stageId >= 0 && killFlag && progressListener.activeStages.contains(stageId)) { - sc.get.cancelStage(stageId) + val stageId = Option(request.getParameter("id")).map(_.toInt) + stageId.foreach { id => + if (progressListener.activeStages.contains(id)) { + sc.foreach(_.cancelStage(id)) + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } } - // Do a quick pause here to give Spark time to kill the stage so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index fd12a21b7927..e5d408a16736 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -194,6 +194,22 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } + withSpark(newSparkContext(killEnabled = true)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/jobs") + assert(hasKillLink) + } + } + + withSpark(newSparkContext(killEnabled = false)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/jobs") + assert(!hasKillLink) + } + } + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { @@ -453,20 +469,24 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } test("kill stage POST/GET response is correct") { - def getResponseCode(url: URL, method: String): Int = { - val connection = url.openConnection().asInstanceOf[HttpURLConnection] - connection.setRequestMethod(method) - connection.connect() - val code = connection.getResponseCode() - connection.disconnect() - code + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + // SPARK-6846: should be POST only but YARN AM doesn't proxy POST + getResponseCode(url, 
"GET") should be (200) + getResponseCode(url, "POST") should be (200) + } } + } + test("kill job POST/GET response is correct") { withSpark(newSparkContext(killEnabled = true)) { sc => sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") + sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -651,6 +671,17 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + try { + connection.connect() + connection.getResponseCode() + } finally { + connection.disconnect() + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/docs/configuration.md b/docs/configuration.md index b07867d99aa9..6600cb6c0ac0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -632,7 +632,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.killEnabled true - Allows stages and corresponding jobs to be killed from the web ui. + Allows jobs and stages to be killed from the web UI. From 402205ddf749e7478683ce1b0443df63b46b03fd Mon Sep 17 00:00:00 2001 From: Shuai Lin Date: Wed, 26 Oct 2016 14:31:47 +0200 Subject: [PATCH 019/381] [SPARK-17802] Improved caller context logging. ## What changes were proposed in this pull request? [SPARK-16757](https://issues.apache.org/jira/browse/SPARK-16757) sets the hadoop `CallerContext` when calling hadoop/hdfs apis to make spark applications more diagnosable in hadoop/hdfs logs. However, the `org.apache.hadoop.ipc.CallerContext` class is only added since [hadoop 2.8](https://issues.apache.org/jira/browse/HDFS-9184), which is not officially releaed yet. So each time `utils.CallerContext.setCurrentContext()` is called (e.g [when a task is created](https://github.com/apache/spark/blob/b678e46/core/src/main/scala/org/apache/spark/scheduler/Task.scala#L95-L96)), a "java.lang.ClassNotFoundException: org.apache.hadoop.ipc.CallerContext" error is logged, which pollutes the spark logs when there are lots of tasks. This patch improves this behaviour by only logging the `ClassNotFoundException` once. ## How was this patch tested? Existing tests. Author: Shuai Lin Closes #15377 from lins05/spark-17802-improve-callercontext-logging. 
--- .../scala/org/apache/spark/util/Utils.scala | 48 +++++++++++++------ .../org/apache/spark/util/UtilsSuite.scala | 7 +-- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bfc609419ccd..e57eb0de2689 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2508,6 +2508,26 @@ private[spark] object Utils extends Logging { } } +private[util] object CallerContext extends Logging { + val callerContextSupported: Boolean = { + SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { + try { + // scalastyle:off classforname + Class.forName("org.apache.hadoop.ipc.CallerContext") + Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname + true + } catch { + case _: ClassNotFoundException => + false + case NonFatal(e) => + logWarning("Fail to load the CallerContext class", e) + false + } + } + } +} + /** * An utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be * constructed by parameters passed in. @@ -2554,21 +2574,21 @@ private[spark] class CallerContext( * Set up the caller context [[context]] by invoking Hadoop CallerContext API of * [[org.apache.hadoop.ipc.CallerContext]], which was added in hadoop 2.8. */ - def setCurrentContext(): Boolean = { - var succeed = false - try { - // scalastyle:off classforname - val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") - val Builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname - val builderInst = Builder.getConstructor(classOf[String]).newInstance(context) - val hdfsContext = Builder.getMethod("build").invoke(builderInst) - callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) - succeed = true - } catch { - case NonFatal(e) => logInfo("Fail to set Spark caller context", e) + def setCurrentContext(): Unit = { + if (CallerContext.callerContextSupported) { + try { + // scalastyle:off classforname + val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") + val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") + // scalastyle:on classforname + val builderInst = builder.getConstructor(classOf[String]).newInstance(context) + val hdfsContext = builder.getMethod("build").invoke(builderInst) + callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) + } catch { + case NonFatal(e) => + logWarning("Fail to set Spark caller context", e) + } } - succeed } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4dda80f10a08..aeb2969fd579 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -843,14 +843,11 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("Set Spark CallerContext") { val context = "test" - try { + new CallerContext(context).setCurrentContext() + if (CallerContext.callerContextSupported) { val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") - assert(new CallerContext(context).setCurrentContext()) assert(s"SPARK_$context" === callerContext.getMethod("getCurrent").invoke(null).toString) - } catch { - case e: ClassNotFoundException => - assert(!new 
CallerContext(context).setCurrentContext()) } } From 3c023570b28bc1ed24f5b2448311130fd1777fd3 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 26 Oct 2016 17:09:48 +0200 Subject: [PATCH 020/381] [SPARK-17733][SQL] InferFiltersFromConstraints rule never terminates for query ## What changes were proposed in this pull request? The function `QueryPlan.inferAdditionalConstraints` and `UnaryNode.getAliasedConstraints` can produce a non-converging set of constraints for recursive functions. For instance, if we have two constraints of the form(where a is an alias): `a = b, a = f(b, c)` Applying both these rules in the next iteration would infer: `f(b, c) = f(f(b, c), c)` This process repeated, the iteration won't converge and the set of constraints will grow larger and larger until OOM. ~~To fix this problem, we collect alias from expressions and skip infer constraints if we are to transform an `Expression` to another which contains it.~~ To fix this problem, we apply additional check in `inferAdditionalConstraints`, when it's possible to generate recursive constraints, we skip generate that. ## How was this patch tested? Add new testcase in `SQLQuerySuite`/`InferFiltersFromConstraintsSuite`. Author: jiangxingbo Closes #15319 from jiangxb1987/constraints. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 88 +++++++++++++++++-- .../InferFiltersFromConstraintsSuite.scala | 87 +++++++++++++++++- .../spark/sql/catalyst/plans/PlanTest.scala | 25 +++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 5 +- 4 files changed, 191 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0fb6e7d2e795..45ee2964d4db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -68,26 +68,104 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case _ => Seq.empty[Attribute] } + // Collect aliases from expressions, so we may avoid producing recursive constraints. + private lazy val aliasMap = AttributeMap( + (expressions ++ children.flatMap(_.expressions)).collect { + case a: Alias => (a.toAttribute, a.child) + }) + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5` + * additional constraint of the form `b = 5`. + * + * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` + * as they are often useless and can lead to a non-converging set of constraints. 
*/ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val constraintClasses = generateEquivalentConstraintClasses(constraints) + var inferredConstraints = Set.empty[Expression] constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(l) => r + val candidateConstraints = constraints - eq + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(l) && + !isRecursiveDeduction(r, constraintClasses) => r }) - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(r) => l + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(r) && + !isRecursiveDeduction(l, constraintClasses) => l }) case _ => // No inference } inferredConstraints -- constraints } + /* + * Generate a sequence of expression sets from constraints, where each set stores an equivalence + * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following + * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal + * to an selected attribute. + */ + private def generateEquivalentConstraintClasses( + constraints: Set[Expression]): Seq[Set[Expression]] = { + var constraintClasses = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for an equivalence constraint class. + val leftConstraintClass = getConstraintClass(left, constraintClasses) + val rightConstraintClass = getConstraintClass(right, constraintClasses) + if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { + // Combine the two sets. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ + (leftConstraintClass ++ rightConstraintClass) + } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty + // Update equivalence class of `left` expression. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) + } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty + // Update equivalence class of `right` expression. + constraintClasses = constraintClasses + .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) + } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty + // Create new equivalence constraint class since neither expression presents + // in any classes. + constraintClasses = constraintClasses :+ Set(left, right) + } + case _ => // Skip + } + + constraintClasses + } + + /* + * Get all expressions equivalent to the selected expression. + */ + private def getConstraintClass( + expr: Expression, + constraintClasses: Seq[Set[Expression]]): Set[Expression] = + constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) + + /* + * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it + * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is a child of the referenced expression. 
+ */ + private def isRecursiveDeduction( + attr: Attribute, + constraintClasses: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getConstraintClass(expr, constraintClasses).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For * example, if this set contains the expression `a = 2` then that expression is guaranteed to diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e7fdd5a6202b..9f57f66a2ea2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._ class InferFiltersFromConstraintsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) :: - Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) :: - Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + InferFiltersFromConstraints, + CombineFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -120,4 +123,82 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("inner join with alias: alias contains multiple attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))) + .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2.where(IsNotNull('a)), Inner, + Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: alias contains single attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b) + .select('a, 'b.as('d)).as("t") + .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner, + Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: don't generate constraints for recursive functions") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, + Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 
+ .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a)) + && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)) + && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b + && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) + && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) + && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b) + .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner, + Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr + && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("generate correct filters for alias that don't produce recursive constraints") { + val t1 = testRelation.subquery('t1) + + val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze + val correctAnswer = + t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b)) + .select('a.as('x), 'b.as('y)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6310f0c2bc0e..64e268703bf5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ /** @@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) * etc., will all now be equivalent. * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. */ private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And), child) case sample: Sample => sample.copy(seed = 0L)(true) + case join @ Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) } } + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). 
+ */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizePlan(normalizeExprIds(plan1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 60978efddd7f..bd4c25315c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,12 +19,9 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.{AccumulatorSuite, SparkException} -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} From 4bee9540790a40acb74db4b0b44c364c4b3f537d Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 26 Oct 2016 09:07:30 -0700 Subject: [PATCH 021/381] =?UTF-8?q?[SPARK-18093][SQL]=20Fix=20default=20va?= =?UTF-8?q?lue=20test=20in=20SQLConfSuite=20to=20work=20rega=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rdless of warehouse dir's existence ## What changes were proposed in this pull request? Appending a trailing slash, if there already isn't one for the sake comparison of the two paths. It doesn't take away from the essence of the check, but removes any potential mismatch due to lack of trailing slash. ## How was this patch tested? Ran unit tests and they passed. Author: Mark Grover Closes #15623 from markgrover/spark-18093. 
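For reference, the normalization the updated assertion relies on is just trailing-slash-insensitive string comparison; a tiny hedged sketch (the paths are made up for illustration):

```scala
// Compare two warehouse paths while ignoring an optional trailing slash,
// mirroring the stripSuffix("/") normalization used in the test below.
def samePath(a: String, b: String): Boolean =
  a.stripSuffix("/") == b.stripSuffix("/")

// samePath("file:/tmp/spark-warehouse/", "file:/tmp/spark-warehouse") == true
```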
--- .../scala/org/apache/spark/sql/internal/SQLConfSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index a89a43fa1e77..11d4693f1c2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -215,12 +215,15 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("default value of WAREHOUSE_PATH") { + val original = spark.conf.get(SQLConf.WAREHOUSE_PATH) try { // to get the default value, always unset it spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) - assert(new Path(Utils.resolveURI("spark-warehouse")).toString === - spark.sessionState.conf.warehousePath + "/") + // JVM adds a trailing slash if the directory exists and leaves it as-is, if it doesn't + // In our comparison, strip trailing slash off of both sides, to account for such cases + assert(new Path(Utils.resolveURI("spark-warehouse")).toString.stripSuffix("/") === spark + .sessionState.conf.warehousePath.stripSuffix("/")) } finally { sql(s"set ${SQLConf.WAREHOUSE_PATH}=$original") } From 312ea3f7f65532818e11016d6d780ad47485175f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Oct 2016 09:28:28 -0700 Subject: [PATCH 022/381] [SPARK-17748][FOLLOW-UP][ML] Reorg variables of WeightedLeastSquares. ## What changes were proposed in this pull request? This is follow-up work of #15394. Reorg some variables of ```WeightedLeastSquares``` and fix one minor issue of ```WeightedLeastSquaresSuite```. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15621 from yanboliang/spark-17748. --- .../spark/ml/optim/WeightedLeastSquares.scala | 139 ++++++++++-------- .../ml/optim/WeightedLeastSquaresSuite.scala | 15 +- 2 files changed, 86 insertions(+), 68 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 2223f126f1b6..90c24e1b590e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -101,23 +101,19 @@ private[ml] class WeightedLeastSquares( summary.validate() logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k + val numFeatures = summary.k val triK = summary.triK val wSum = summary.wSum - val bBar = summary.bBar - val bbBar = summary.bbBar - val aBar = summary.aBar - val aStd = summary.aStd - val abBar = summary.abBar - val aaBar = summary.aaBar - val numFeatures = abBar.size + val rawBStd = summary.bStd + val rawBBar = summary.bBar // if b is constant (rawBStd is zero), then b cannot be scaled. In this case - // setting bStd=abs(bBar) ensures that b is not scaled anymore in l-bfgs algorithm. - val bStd = if (rawBStd == 0.0) math.abs(bBar) else rawBStd + // setting bStd=abs(rawBBar) ensures that b is not scaled anymore in l-bfgs algorithm. 
+ val bStd = if (rawBStd == 0.0) math.abs(rawBBar) else rawBStd if (rawBStd == 0) { - if (fitIntercept || bBar == 0.0) { - if (bBar == 0.0) { + if (fitIntercept || rawBBar == 0.0) { + if (rawBBar == 0.0) { logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + s"and the intercept will all be zero; as a result, training is not needed.") } else { @@ -126,7 +122,7 @@ private[ml] class WeightedLeastSquares( s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) - val intercept = bBar + val intercept = rawBBar val diagInvAtWA = new DenseVector(Array(0D)) return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA, Array(0D)) } else { @@ -137,53 +133,70 @@ private[ml] class WeightedLeastSquares( } } - // scale aBar to standardized space in-place - val aBarValues = aBar.values - var j = 0 - while (j < numFeatures) { - if (aStd(j) == 0.0) { - aBarValues(j) = 0.0 - } else { - aBarValues(j) /= aStd(j) - } - j += 1 - } + val bBar = summary.bBar / bStd + val bbBar = summary.bbBar / (bStd * bStd) - // scale abBar to standardized space in-place - val abBarValues = abBar.values + val aStd = summary.aStd val aStdValues = aStd.values - j = 0 - while (j < numFeatures) { - if (aStdValues(j) == 0.0) { - abBarValues(j) = 0.0 - } else { - abBarValues(j) /= (aStdValues(j) * bStd) + + val aBar = { + val _aBar = summary.aBar + val _aBarValues = _aBar.values + var i = 0 + // scale aBar to standardized space in-place + while (i < numFeatures) { + if (aStdValues(i) == 0.0) { + _aBarValues(i) = 0.0 + } else { + _aBarValues(i) /= aStdValues(i) + } + i += 1 } - j += 1 + _aBar } + val aBarValues = aBar.values - // scale aaBar to standardized space in-place - val aaBarValues = aaBar.values - j = 0 - var p = 0 - while (j < numFeatures) { - val aStdJ = aStdValues(j) + val abBar = { + val _abBar = summary.abBar + val _abBarValues = _abBar.values var i = 0 - while (i <= j) { - val aStdI = aStdValues(i) - if (aStdJ == 0.0 || aStdI == 0.0) { - aaBarValues(p) = 0.0 + // scale abBar to standardized space in-place + while (i < numFeatures) { + if (aStdValues(i) == 0.0) { + _abBarValues(i) = 0.0 } else { - aaBarValues(p) /= (aStdI * aStdJ) + _abBarValues(i) /= (aStdValues(i) * bStd) } - p += 1 i += 1 } - j += 1 + _abBar } + val abBarValues = abBar.values - val bBarStd = bBar / bStd - val bbBarStd = bbBar / (bStd * bStd) + val aaBar = { + val _aaBar = summary.aaBar + val _aaBarValues = _aaBar.values + var j = 0 + var p = 0 + // scale aaBar to standardized space in-place + while (j < numFeatures) { + val aStdJ = aStdValues(j) + var i = 0 + while (i <= j) { + val aStdI = aStdValues(i) + if (aStdJ == 0.0 || aStdI == 0.0) { + _aaBarValues(p) = 0.0 + } else { + _aaBarValues(p) /= (aStdI * aStdJ) + } + p += 1 + i += 1 + } + j += 1 + } + _aaBar + } + val aaBarValues = aaBar.values val effectiveRegParam = regParam / bStd val effectiveL1RegParam = elasticNetParam * effectiveRegParam @@ -191,11 +204,11 @@ private[ml] class WeightedLeastSquares( // add L2 regularization to diagonals var i = 0 - j = 2 + var j = 2 while (i < triK) { var lambda = effectiveL2RegParam if (!standardizeFeatures) { - val std = aStd(j - 2) + val std = aStdValues(j - 2) if (std != 0.0) { lambda /= (std * std) } else { @@ -209,8 +222,9 @@ private[ml] class WeightedLeastSquares( i += j j += 1 } - val aa = getAtA(aaBar.values, aBar.values) - val ab = getAtB(abBar.values, bBarStd) + + val aa = getAtA(aaBarValues, aBarValues) + val ab = getAtB(abBarValues, bBar) val solver = if ((solverType == 
WeightedLeastSquares.Auto && elasticNetParam != 0.0 && regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) { @@ -237,22 +251,23 @@ private[ml] class WeightedLeastSquares( val solution = solver match { case cholesky: CholeskySolver => try { - cholesky.solve(bBarStd, bbBarStd, ab, aa, aBar) + cholesky.solve(bBar, bbBar, ab, aa, aBar) } catch { // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to - // quasi-newton solver + // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them - val _aa = getAtA(aaBar.values, aBar.values) - val _ab = getAtB(abBar.values, bBarStd) + val _aa = getAtA(aaBarValues, aBarValues) + val _ab = getAtB(abBarValues, bBar) val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None) - newSolver.solve(bBarStd, bbBarStd, _ab, _aa, aBar) + newSolver.solve(bBar, bbBar, _ab, _aa, aBar) } case qn: QuasiNewtonSolver => - qn.solve(bBarStd, bbBarStd, ab, aa, aBar) + qn.solve(bBar, bbBar, ab, aa, aBar) } + val (coefficientArray, intercept) = if (fitIntercept) { (solution.coefficients.slice(0, solution.coefficients.length - 1), solution.coefficients.last * bStd) @@ -271,7 +286,11 @@ private[ml] class WeightedLeastSquares( // aaInv is a packed upper triangular matrix, here we get all elements on diagonal val diagInvAtWA = solution.aaInv.map { inv => new DenseVector((1 to k).map { i => - val multiplier = if (i == k && fitIntercept) 1.0 else aStdValues(i - 1) * aStdValues(i - 1) + val multiplier = if (i == k && fitIntercept) { + 1.0 + } else { + aStdValues(i - 1) * aStdValues(i - 1) + } inv(i + (i - 1) * i / 2 - 1) / (wSum * multiplier) }.toArray) }.getOrElse(new DenseVector(Array(0D))) @@ -280,7 +299,7 @@ private[ml] class WeightedLeastSquares( solution.objectiveHistory.getOrElse(Array(0D))) } - /** Construct A^T^ A from summary statistics. */ + /** Construct A^T^ A (append bias if necessary). */ private def getAtA(aaBar: Array[Double], aBar: Array[Double]): DenseVector = { if (fitIntercept) { new DenseVector(Array.concat(aaBar, aBar, Array(1.0))) @@ -289,7 +308,7 @@ private[ml] class WeightedLeastSquares( } } - /** Construct A^T^ b from summary statistics. */ + /** Construct A^T^ b (append bias if necessary). 
*/ private def getAtB(abBar: Array[Double], bBar: Double): DenseVector = { if (fitIntercept) { new DenseVector(Array.concat(abBar, Array(bBar))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 3cdab0327991..093d02ea7a14 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -361,14 +361,13 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext for (fitIntercept <- Seq(false, true); standardization <- Seq(false, true); (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) { - for (solver <- Seq(WeightedLeastSquares.Auto, WeightedLeastSquares.Cholesky)) { - val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha, - standardizeFeatures = standardization, standardizeLabel = true, - solverType = WeightedLeastSquares.QuasiNewton) - val model = wls.fit(constantFeaturesInstances) - val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) - assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6) - } + val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.QuasiNewton) + val model = wls.fit(constantFeaturesInstances) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6) + idx += 1 } } From 7ac70e7ba8d610a45c21a70dc28e4c989c19451b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 26 Oct 2016 10:36:36 -0700 Subject: [PATCH 023/381] [SPARK-13747][SQL] Fix concurrent executions in ForkJoinPool for SQL ## What changes were proposed in this pull request? Calling `Await.result` will allow other tasks to be run on the same thread when using ForkJoinPool. However, SQL uses a `ThreadLocal` execution id to trace Spark jobs launched by a query, which doesn't work perfectly in ForkJoinPool. This PR just uses `Awaitable.result` instead to prevent ForkJoinPool from running other tasks in the current waiting thread. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15520 from zsxwing/SPARK-13747. --- .../org/apache/spark/util/ThreadUtils.scala | 21 +++++++++++++++++++ scalastyle-config.xml | 1 + .../execution/basicPhysicalOperators.scala | 2 +- .../exchange/BroadcastExchangeExec.scala | 3 ++- 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 5a6dbc830448..d093e7bfc3da 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -194,4 +194,25 @@ private[spark] object ThreadUtils { throw new SparkException("Exception thrown in awaitResult: ", t) } } + + /** + * Calls [[Awaitable.result]] directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps + * and re-throws any exceptions with nice stack track. + * + * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. 
As concurrent + * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method + * basically prevents ForkJoinPool from running other tasks in the current waiting thread. + */ + @throws(classOf[SparkException]) + def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.result(Duration.Inf)(awaitPermission) + } catch { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7fe0697202cd..81d57d723a72 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -200,6 +200,7 @@ This file is divided into 3 sections: // scalastyle:off awaitresult Await.result(...) // scalastyle:on awaitresult + If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead. ]]> diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 37d750e621c2..a5291e0c12f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -570,7 +570,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { } override def executeCollect(): Array[InternalRow] = { - ThreadUtils.awaitResult(relationFuture, Duration.Inf) + ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a76..ce5013daeb1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -128,7 +128,8 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout) + .asInstanceOf[broadcast.Broadcast[T]] } } From fa7d9d70825a6816495d239da925d0087f7cb94f Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 26 Oct 2016 20:12:20 +0200 Subject: [PATCH 024/381] [SPARK-18063][SQL] Failed to infer constraints over multiple aliases ## What changes were proposed in this pull request? The `UnaryNode.getAliasedConstraints` function fails to replace all expressions by their alias where constraints contains more than one expression to be replaced. For example: ``` val tr = LocalRelation('a.int, 'b.string, 'c.int) val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) multiAlias.analyze.constraints ``` currently outputs: ``` ExpressionSet(Seq( IsNotNull(resolveColumn(multiAlias.analyze, "x")), IsNotNull(resolveColumn(multiAlias.analyze, "y")) ) ``` The constraint `resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)` is missing. ## How was this patch tested? Add new test cases in `ConstraintPropagationSuite`. 
Author: jiangxingbo Closes #15597 from jiangxb1987/alias-constraints. --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 16 ++++++++++------ .../plans/ConstraintPropagationSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 09725473a384..b0a4145f3776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -293,15 +293,19 @@ abstract class UnaryNode extends LogicalPlan { * expressions with the corresponding alias */ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { - projectList.flatMap { + var allConstraints = child.constraints.asInstanceOf[Set[Expression]] + projectList.foreach { case a @ Alias(e, _) => - child.constraints.map(_ transform { + // For every alias in `projectList`, replace the reference in constraints by its attribute. + allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet + }) + allConstraints += EqualNullSafe(e, a.toAttribute) + case _ => // Don't change. + } + + allConstraints -- child.constraints } override protected def validConstraints: Set[Expression] = child.constraints diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 8d6a49a8a37b..8068ce922e63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -128,8 +128,16 @@ class ConstraintPropagationSuite extends SparkFunSuite { ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + + val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) + verifyConstraints(multiAlias.analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), + IsNotNull(resolveColumn(multiAlias.analyze, "y")), + resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)) + ) } test("propagating constraints in union") { From 7d10631c16b980adf1f55378c128436310daed65 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 26 Oct 2016 11:16:20 -0700 Subject: [PATCH 025/381] [SPARK-18104][DOC] Don't build KafkaSource doc ## What changes were proposed in this pull request? Don't need to build doc for KafkaSource because the user should use the data source APIs to use KafkaSource. All KafkaSource APIs are internal. ## How was this patch tested? Verified manually. Author: Shixiong Zhu Closes #15630 from zsxwing/kafka-unidoc. 
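Since the KafkaSource classes themselves stay internal, the supported entry point is the generic data source API. A minimal sketch of that usage follows; the broker address, topic name, and local master are placeholders, and it assumes the spark-sql-kafka-0-10 artifact is on the classpath.

```scala
import org.apache.spark.sql.SparkSession

object KafkaSourceUsage {
  def main(args: Array[String]): Unit = {
    // Local master only for trying the sketch out; drop it when using spark-submit.
    val spark = SparkSession.builder()
      .appName("kafka-source-sketch")
      .master("local[2]")
      .getOrCreate()

    // The Kafka source exposes binary key/value columns; cast them for readability.
    val stream = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")
      .option("subscribe", "events")
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

    stream.writeStream
      .format("console")
      .start()
      .awaitTermination()
  }
}
```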
--- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 88d5dc9b02dd..2d3a95b163a7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -714,9 +714,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) From ea3605e82545031a00235ee0f449e1e2418674e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Oct 2016 11:48:54 -0700 Subject: [PATCH 026/381] [MINOR][ML] Refactor clustering summary. ## What changes were proposed in this pull request? Abstract ```ClusteringSummary``` from ```KMeansSummary```, ```GaussianMixtureSummary``` and ```BisectingSummary```, and eliminate duplicated pieces of code. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15555 from yanboliang/clustering-summary. --- .../spark/ml/clustering/BisectingKMeans.scala | 36 +++---------- .../ml/clustering/ClusteringSummary.scala | 54 +++++++++++++++++++ .../spark/ml/clustering/GaussianMixture.scala | 37 ++++--------- .../apache/spark/ml/clustering/KMeans.scala | 36 +++---------- 4 files changed, 80 insertions(+), 83 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index ef2d918ea354..2718dd93dcb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -288,35 +288,15 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * :: Experimental :: * Summary of BisectingKMeans. * - * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]] - * @param predictionCol Name for column of predicted clusters in `predictions` - * @param featuresCol Name for column of features in `predictions` - * @param k Number of clusters + * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. */ @Since("2.1.0") @Experimental class BisectingKMeansSummary private[clustering] ( - @Since("2.1.0") @transient val predictions: DataFrame, - @Since("2.1.0") val predictionCol: String, - @Since("2.1.0") val featuresCol: String, - @Since("2.1.0") val k: Int) extends Serializable { - - /** - * Cluster centers of the transformed data. - */ - @Since("2.1.0") - @transient lazy val cluster: DataFrame = predictions.select(predictionCol) - - /** - * Size of (number of data points in) each cluster. 
- */ - @Since("2.1.0") - lazy val clusterSizes: Array[Long] = { - val sizes = Array.fill[Long](k)(0) - cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { - case Row(cluster: Int, count: Long) => sizes(cluster) = count - } - sizes - } - -} + predictions: DataFrame, + predictionCol: String, + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala new file mode 100644 index 000000000000..8b5f525194f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{DataFrame, Row} + +/** + * :: Experimental :: + * Summary of clustering algorithms. + * + * @param predictions [[DataFrame]] produced by model.transform(). + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. + */ +@Experimental +class ClusteringSummary private[clustering] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val featuresCol: String, + val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of (number of data points in) each cluster. + */ + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 69f060ad7711..e3cb92f4f144 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -356,42 +356,25 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { * :: Experimental :: * Summary of GaussianMixture. 
* - * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]] - * @param predictionCol Name for column of predicted clusters in `predictions` - * @param probabilityCol Name for column of predicted probability of each cluster in `predictions` - * @param featuresCol Name for column of features in `predictions` - * @param k Number of clusters + * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param probabilityCol Name for column of predicted probability of each cluster + * in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. */ @Since("2.0.0") @Experimental class GaussianMixtureSummary private[clustering] ( - @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, + predictions: DataFrame, + predictionCol: String, @Since("2.0.0") val probabilityCol: String, - @Since("2.0.0") val featuresCol: String, - @Since("2.0.0") val k: Int) extends Serializable { - - /** - * Cluster centers of the transformed data. - */ - @Since("2.0.0") - @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { /** * Probability of each cluster. */ @Since("2.0.0") @transient lazy val probability: DataFrame = predictions.select(probabilityCol) - - /** - * Size of (number of data points in) each cluster. - */ - @Since("2.0.0") - lazy val clusterSizes: Array[Long] = { - val sizes = Array.fill[Long](k)(0) - cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { - case Row(cluster: Int, count: Long) => sizes(cluster) = count - } - sizes - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 0d2405b50068..05ed3223ae53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -346,35 +346,15 @@ object KMeans extends DefaultParamsReadable[KMeans] { * :: Experimental :: * Summary of KMeans. * - * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] - * @param predictionCol Name for column of predicted clusters in `predictions` - * @param featuresCol Name for column of features in `predictions` - * @param k Number of clusters + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]. + * @param predictionCol Name for column of predicted clusters in `predictions`. + * @param featuresCol Name for column of features in `predictions`. + * @param k Number of clusters. */ @Since("2.0.0") @Experimental class KMeansSummary private[clustering] ( - @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val featuresCol: String, - @Since("2.0.0") val k: Int) extends Serializable { - - /** - * Cluster centers of the transformed data. - */ - @Since("2.0.0") - @transient lazy val cluster: DataFrame = predictions.select(predictionCol) - - /** - * Size of (number of data points in) each cluster. 
- */ - @Since("2.0.0") - lazy val clusterSizes: Array[Long] = { - val sizes = Array.fill[Long](k)(0) - cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { - case Row(cluster: Int, count: Long) => sizes(cluster) = count - } - sizes - } - -} + predictions: DataFrame, + predictionCol: String, + featuresCol: String, + k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) From fb0a8a8dd7e8985676a846684b956e2d988875c6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 26 Oct 2016 13:26:43 -0700 Subject: [PATCH 027/381] [SPARK-17961][SPARKR][SQL] Add storageLevel to DataFrame for SparkR ## What changes were proposed in this pull request? Add storageLevel to DataFrame for SparkR. This is similar to this RP: https://github.com/apache/spark/pull/13780 but in R I do not make a class for `StorageLevel` but add a method `storageToString` ## How was this patch tested? test added. Author: WeichenXu Closes #15516 from WeichenXu123/storageLevel_df_r. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 28 +++++++++++++++- R/pkg/R/RDD.R | 2 +- R/pkg/R/generics.R | 6 +++- R/pkg/R/utils.R | 41 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 ++- 6 files changed, 79 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 87181851714e..eb314f471893 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -124,6 +124,7 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "storageLevel", "subset", "summarize", "summary", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b6ce838969a4..be34e4b32f6f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -633,7 +633,7 @@ setMethod("persist", #' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions -#' @rdname unpersist-methods +#' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist #' @export @@ -654,6 +654,32 @@ setMethod("unpersist", x }) +#' StorageLevel +#' +#' Get storagelevel of this SparkDataFrame. +#' +#' @param x the SparkDataFrame to get the storageLevel. +#' +#' @family SparkDataFrame functions +#' @rdname storageLevel +#' @aliases storageLevel,SparkDataFrame-method +#' @name storageLevel +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' persist(df, "MEMORY_AND_DISK") +#' storageLevel(df) +#'} +#' @note storageLevel since 2.1.0 +setMethod("storageLevel", + signature(x = "SparkDataFrame"), + function(x) { + storageLevelToString(callJMethod(x@sdf, "storageLevel")) + }) + #' Repartition #' #' The following options for repartition are possible: diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 6cd0704003f1..0f1162fec1df 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -261,7 +261,7 @@ setMethod("persistRDD", #' cache(rdd) # rdd@@env$isCached == TRUE #' unpersistRDD(rdd) # rdd@@env$isCached == FALSE #'} -#' @rdname unpersist-methods +#' @rdname unpersist #' @aliases unpersist,RDD-method #' @noRd setMethod("unpersistRDD", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5549cd7cac51..4569fe489046 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -691,6 +691,10 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x, ...) 
{ standardGeneric("showDF") }) +# @rdname storageLevel +# @export +setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) + #' @rdname subset #' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) @@ -715,7 +719,7 @@ setGeneric("union", function(x, y) { standardGeneric("union") }) #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) -#' @rdname unpersist-methods +#' @rdname unpersist #' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fa8bb0f79ce8..c4e78cbb804d 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -385,6 +385,47 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } +storageLevelToString <- function(levelObj) { + useDisk <- callJMethod(levelObj, "useDisk") + useMemory <- callJMethod(levelObj, "useMemory") + useOffHeap <- callJMethod(levelObj, "useOffHeap") + deserialized <- callJMethod(levelObj, "deserialized") + replication <- callJMethod(levelObj, "replication") + shortName <- if (!useDisk && !useMemory && !useOffHeap && !deserialized && replication == 1) { + "NONE" + } else if (useDisk && !useMemory && !useOffHeap && !deserialized && replication == 1) { + "DISK_ONLY" + } else if (useDisk && !useMemory && !useOffHeap && !deserialized && replication == 2) { + "DISK_ONLY_2" + } else if (!useDisk && useMemory && !useOffHeap && deserialized && replication == 1) { + "MEMORY_ONLY" + } else if (!useDisk && useMemory && !useOffHeap && deserialized && replication == 2) { + "MEMORY_ONLY_2" + } else if (!useDisk && useMemory && !useOffHeap && !deserialized && replication == 1) { + "MEMORY_ONLY_SER" + } else if (!useDisk && useMemory && !useOffHeap && !deserialized && replication == 2) { + "MEMORY_ONLY_SER_2" + } else if (useDisk && useMemory && !useOffHeap && deserialized && replication == 1) { + "MEMORY_AND_DISK" + } else if (useDisk && useMemory && !useOffHeap && deserialized && replication == 2) { + "MEMORY_AND_DISK_2" + } else if (useDisk && useMemory && !useOffHeap && !deserialized && replication == 1) { + "MEMORY_AND_DISK_SER" + } else if (useDisk && useMemory && !useOffHeap && !deserialized && replication == 2) { + "MEMORY_AND_DISK_SER_2" + } else if (useDisk && useMemory && useOffHeap && !deserialized && replication == 1) { + "OFF_HEAP" + } else { + NULL + } + fullInfo <- callJMethod(levelObj, "toString") + if (is.null(shortName)) { + fullInfo + } else { + paste(shortName, "-", fullInfo) + } +} + # Utility function for functions where an argument needs to be integer but we want to allow # the user to type (for example) `5` instead of `5L` to avoid a confusing error message. 
numToInt <- function(num) { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index e77dbde44ee6..9289db57b6d6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -796,7 +796,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v expect_false(collectRDD(second)[[3]]$testCol) }) -test_that("cache(), persist(), and unpersist() on a DataFrame", { +test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", { df <- read.json(jsonPath) expect_false(df@env$isCached) cache(df) @@ -808,6 +808,9 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { persist(df, "MEMORY_AND_DISK") expect_true(df@env$isCached) + expect_equal(storageLevel(df), + "MEMORY_AND_DISK - StorageLevel(disk, memory, deserialized, 1 replicas)") + unpersist(df) expect_false(df@env$isCached) From dcdda19785a272969fb1e3ec18382403aaad6c91 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Wed, 26 Oct 2016 13:33:23 -0700 Subject: [PATCH 028/381] [SPARK-14300][DOCS][MLLIB] Scala MLlib examples code merge and clean up ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14300 Duplicated code found in scala/examples/mllib, below all deleted in this PR: - DenseGaussianMixture.scala - StreamingLinearRegression.scala ## delete reasons: #### delete: mllib/DenseGaussianMixture.scala - duplicate of mllib/GaussianMixtureExample #### delete: mllib/StreamingLinearRegression.scala - duplicate of mllib/StreamingLinearRegressionExample When merging and cleaning those code, be sure not disturb the previous example on and off blocks. ## How was this patch tested? Test with `SKIP_API=1 jekyll` manually to make sure that works well. Author: Xin Ren Closes #12195 from keypointt/SPARK-14300. --- .../examples/mllib/DenseGaussianMixture.scala | 75 ------------------- .../mllib/StreamingLinearRegression.scala | 73 ------------------ .../StreamingLinearRegressionExample.scala | 19 +++++ 3 files changed, 19 insertions(+), 148 deletions(-) delete mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala deleted file mode 100644 index 90b817b23e15..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.clustering.GaussianMixture -import org.apache.spark.mllib.linalg.Vectors - -/** - * An example Gaussian Mixture Model EM app. Run with - * {{{ - * ./bin/run-example mllib.DenseGaussianMixture - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. - */ -object DenseGaussianMixture { - def main(args: Array[String]): Unit = { - if (args.length < 3) { - println("usage: DenseGmmEM [maxIterations]") - } else { - val maxIterations = if (args.length > 3) args(3).toInt else 100 - run(args(0), args(1).toInt, args(2).toDouble, maxIterations) - } - } - - private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { - val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - - val data = ctx.textFile(inputFile).map { line => - Vectors.dense(line.trim.split(' ').map(_.toDouble)) - }.cache() - - val clusters = new GaussianMixture() - .setK(k) - .setConvergenceTol(convergenceTol) - .setMaxIterations(maxIterations) - .run(data) - - for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) - } - - println("The membership value of each vector to all mixture components (first <= 100):") - val membership = clusters.predictSoft(data) - membership.take(100).foreach { x => - print(" " + x.mkString(",")) - } - println() - println("Cluster labels (first <= 100):") - val clusterLabels = clusters.predict(data) - clusterLabels.take(100).foreach { x => - print(" " + x) - } - println() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala deleted file mode 100644 index e5592966f13f..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.SparkConf -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} -import org.apache.spark.streaming.{Seconds, StreamingContext} - -/** - * Train a linear regression model on one stream of data and make predictions - * on another stream, where the data streams arrive as text files - * into two different directories. 
- * - * The rows of the text files must be labeled data points in the form - * `(y,[x1,x2,x3,...,xn])` - * Where n is the number of features. n must be the same for train and test. - * - * Usage: StreamingLinearRegression - * - * To run on your local machine using the two directories `trainingDir` and `testDir`, - * with updates every 5 seconds, and 2 features per data point, call: - * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2 - * - * As you add text files to `trainingDir` the model will continuously update. - * Anytime you add text files to `testDir`, you'll see predictions from the current model. - * - */ -object StreamingLinearRegression { - - def main(args: Array[String]) { - - if (args.length != 4) { - System.err.println( - "Usage: StreamingLinearRegression ") - System.exit(1) - } - - val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") - val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - - val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) - val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) - - val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.zeros(args(3).toInt)) - - model.trainOn(trainingData) - model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - - ssc.start() - ssc.awaitTermination() - - } - -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala index 0a1cd2d62d5b..2ba1a62e450e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala @@ -26,6 +26,25 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD // $example off$ import org.apache.spark.streaming._ +/** + * Train a linear regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features. n must be the same for train and test. + * + * Usage: StreamingLinearRegressionExample + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example mllib.StreamingLinearRegressionExample trainingDir testDir + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. + * + */ object StreamingLinearRegressionExample { def main(args: Array[String]): Unit = { From 5b7d403c1819c32a6a5b87d470f8de1a8ad7a987 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 26 Oct 2016 23:51:16 +0200 Subject: [PATCH 029/381] [SPARK-18094][SQL][TESTS] Move group analytics test cases from `SQLQuerySuite` into a query file test. ## What changes were proposed in this pull request? Currently we have several test cases for group analytics(ROLLUP/CUBE/GROUPING SETS) in `SQLQuerySuite`, should better move them into a query file test. 
The following test cases are moved to `group-analytics.sql`: ``` test("rollup") test("grouping sets when aggregate functions containing groupBy columns") test("cube") test("grouping sets") test("grouping and grouping_id") test("grouping and grouping_id in having") test("grouping and grouping_id in sort") ``` This is followup work of #15582 ## How was this patch tested? Modified query file `group-analytics.sql`, which will be tested by `SQLQueryTestSuite`. Author: jiangxingbo Closes #15624 from jiangxb1987/group-analytics-test. --- .../sql-tests/inputs/group-analytics.sql | 46 +++- .../sql-tests/results/group-analytics.sql.out | 247 +++++++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 189 -------------- 3 files changed, 290 insertions(+), 192 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 2f783495ddf9..f8135389a9e5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -10,4 +10,48 @@ SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH CUBE; -- ROLLUP on overlapping columns SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH ROLLUP; -SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH ROLLUP; \ No newline at end of file +SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH ROLLUP; + +CREATE OR REPLACE TEMPORARY VIEW courseSales AS SELECT * FROM VALUES +("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2012, 5000), ("dotNET", 2013, 48000), ("Java", 2013, 30000) +AS courseSales(course, year, earnings); + +-- ROLLUP +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY course, year; + +-- CUBE +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, year; + +-- GROUPING SETS +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year); +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course); +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year); + +-- GROUPING SETS with aggregate functions containing groupBy columns +SELECT course, SUM(earnings) AS sum FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum; +SELECT course, SUM(earnings) AS sum, GROUPING_ID(course, earnings) FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum; + +-- GROUPING/GROUPING_ID +SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales +GROUP BY CUBE(course, year); +SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); + +-- GROUPING/GROUPING_ID in having clause +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; + +-- GROUPING/GROUPING_ID in orderBy clause +SELECT course, year, 
GROUPING(course), GROUPING(year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year; +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year; +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 8ea7de809d19..825e8f5488c8 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 5 +-- Number of queries: 26 -- !query 0 @@ -32,7 +32,6 @@ NULL 2 0 NULL NULL 3 - -- !query 2 SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH CUBE -- !query 2 schema @@ -85,3 +84,247 @@ struct 3 2 2 3 NULL 3 NULL NULL 9 + + +-- !query 5 +CREATE OR REPLACE TEMPORARY VIEW courseSales AS SELECT * FROM VALUES +("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2012, 5000), ("dotNET", 2013, 48000), ("Java", 2013, 30000) +AS courseSales(course, year, earnings) +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY course, year +-- !query 6 schema +struct +-- !query 6 output +NULL NULL 113000 +Java NULL 50000 +Java 2012 20000 +Java 2013 30000 +dotNET NULL 63000 +dotNET 2012 15000 +dotNET 2013 48000 + + +-- !query 7 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, year +-- !query 7 schema +struct +-- !query 7 output +NULL NULL 113000 +NULL 2012 35000 +NULL 2013 78000 +Java NULL 50000 +Java 2012 20000 +Java 2013 30000 +dotNET NULL 63000 +dotNET 2012 15000 +dotNET 2013 48000 + + +-- !query 8 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year) +-- !query 8 schema +struct +-- !query 8 output +Java NULL 50000 +NULL 2012 35000 +NULL 2013 78000 +dotNET NULL 63000 + + +-- !query 9 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course) +-- !query 9 schema +struct +-- !query 9 output +Java NULL 50000 +dotNET NULL 63000 + + +-- !query 10 +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year) +-- !query 10 schema +struct +-- !query 10 output +NULL 2012 35000 +NULL 2013 78000 + + +-- !query 11 +SELECT course, SUM(earnings) AS sum FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +-- !query 11 schema +struct +-- !query 11 output +NULL 113000 +Java 20000 +Java 30000 +Java 50000 +dotNET 5000 +dotNET 10000 +dotNET 48000 +dotNET 63000 + + +-- !query 12 +SELECT course, SUM(earnings) AS sum, GROUPING_ID(course, earnings) FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +-- !query 12 schema +struct +-- !query 12 output +NULL 113000 3 +Java 20000 0 +Java 30000 0 +Java 50000 1 +dotNET 5000 0 +dotNET 10000 0 +dotNET 48000 0 +dotNET 63000 1 + + +-- !query 13 +SELECT course, year, GROUPING(course), 
GROUPING(year), GROUPING_ID(course, year) FROM courseSales +GROUP BY CUBE(course, year) +-- !query 13 schema +struct +-- !query 13 output +Java 2012 0 0 0 +Java 2013 0 0 0 +Java NULL 0 1 1 +NULL 2012 1 0 2 +NULL 2013 1 0 2 +NULL NULL 1 1 3 +dotNET 2012 0 0 0 +dotNET 2013 0 0 0 +dotNET NULL 0 1 1 + + +-- !query 14 +SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +grouping() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 15 +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 16 +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +grouping__id is deprecated; use grouping_id() instead; + + +-- !query 17 +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +-- !query 17 schema +struct +-- !query 17 output +Java NULL +NULL NULL +dotNET NULL + + +-- !query 18 +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 19 +SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0 +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 20 +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +grouping__id is deprecated; use grouping_id() instead; + + +-- !query 21 +SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year +-- !query 21 schema +struct +-- !query 21 output +Java 2012 0 0 +Java 2013 0 0 +dotNET 2012 0 0 +dotNET 2013 0 0 +Java NULL 0 1 +dotNET NULL 0 1 +NULL 2012 1 0 +NULL 2013 1 0 +NULL NULL 1 1 + + +-- !query 22 +SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) +ORDER BY GROUPING(course), GROUPING(year), course, year +-- !query 22 schema +struct +-- !query 22 output +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 +NULL 2012 2 +NULL 2013 2 +NULL NULL 3 + + +-- !query 23 +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course) +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 24 +SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course) +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; + + +-- !query 25 +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +grouping__id is 
deprecated; use grouping_id() instead; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bd4c25315c31..1a43d0b2205c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2005,195 +2005,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } - test("rollup") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + - " order by course, year"), - Row(null, null, 113000.0) :: - Row("Java", null, 50000.0) :: - Row("Java", 2012, 20000.0) :: - Row("Java", 2013, 30000.0) :: - Row("dotNET", null, 63000.0) :: - Row("dotNET", 2012, 15000.0) :: - Row("dotNET", 2013, 48000.0) :: Nil - ) - } - - test("grouping sets when aggregate functions containing groupBy columns") { - checkAnswer( - sql("select course, sum(earnings) as sum from courseSales group by course, earnings " + - "grouping sets((), (course), (course, earnings)) " + - "order by course, sum"), - Row(null, 113000.0) :: - Row("Java", 20000.0) :: - Row("Java", 30000.0) :: - Row("Java", 50000.0) :: - Row("dotNET", 5000.0) :: - Row("dotNET", 10000.0) :: - Row("dotNET", 48000.0) :: - Row("dotNET", 63000.0) :: Nil - ) - - checkAnswer( - sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " + - "group by course, earnings grouping sets((), (course), (course, earnings)) " + - "order by course, sum"), - Row(null, 113000.0, 3) :: - Row("Java", 20000.0, 0) :: - Row("Java", 30000.0, 0) :: - Row("Java", 50000.0, 1) :: - Row("dotNET", 5000.0, 0) :: - Row("dotNET", 10000.0, 0) :: - Row("dotNET", 48000.0, 0) :: - Row("dotNET", 63000.0, 1) :: Nil - ) - } - - test("cube") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), - Row("Java", 2012, 20000.0) :: - Row("Java", 2013, 30000.0) :: - Row("Java", null, 50000.0) :: - Row("dotNET", 2012, 15000.0) :: - Row("dotNET", 2013, 48000.0) :: - Row("dotNET", null, 63000.0) :: - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: - Row(null, null, 113000.0) :: Nil - ) - } - - test("grouping sets") { - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(course, year)"), - Row("Java", null, 50000.0) :: - Row("dotNET", null, 63000.0) :: - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: Nil - ) - - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(course)"), - Row("Java", null, 50000.0) :: - Row("dotNET", null, 63000.0) :: Nil - ) - - checkAnswer( - sql("select course, year, sum(earnings) from courseSales group by course, year " + - "grouping sets(year)"), - Row(null, 2012, 35000.0) :: - Row(null, 2013, 78000.0) :: Nil - ) - } - - test("grouping and grouping_id") { - checkAnswer( - sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" + - " from courseSales group by cube(course, year)"), - Row("Java", 2012, 0, 0, 0) :: - Row("Java", 2013, 0, 0, 0) :: - Row("Java", null, 0, 1, 1) :: - Row("dotNET", 2012, 0, 0, 0) :: - Row("dotNET", 2013, 0, 0, 0) :: - Row("dotNET", null, 0, 1, 1) :: - Row(null, 2012, 1, 0, 2) :: - Row(null, 2013, 1, 0, 2) :: - Row(null, null, 1, 1, 3) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year, grouping(course) 
from courseSales group by course, year") - } - assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year, grouping_id(course, year) from courseSales group by course, year") - } - assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year, grouping__id from courseSales group by cube(course, year)") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - - test("grouping and grouping_id in having") { - checkAnswer( - sql("select course, year from courseSales group by cube(course, year)" + - " having grouping(year) = 1 and grouping_id(course, year) > 0"), - Row("Java", null) :: - Row("dotNET", null) :: - Row(null, null) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " having grouping(course) > 0") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " having grouping_id(course, year) > 0") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by cube(course, year)" + - " having grouping__id > 0") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - - test("grouping and grouping_id in sort") { - checkAnswer( - sql("select course, year, grouping(course), grouping(year) from courseSales" + - " group by cube(course, year) order by grouping_id(course, year), course, year"), - Row("Java", 2012, 0, 0) :: - Row("Java", 2013, 0, 0) :: - Row("dotNET", 2012, 0, 0) :: - Row("dotNET", 2013, 0, 0) :: - Row("Java", null, 0, 1) :: - Row("dotNET", null, 0, 1) :: - Row(null, 2012, 1, 0) :: - Row(null, 2013, 1, 0) :: - Row(null, null, 1, 1) :: Nil - ) - - checkAnswer( - sql("select course, year, grouping_id(course, year) from courseSales" + - " group by cube(course, year) order by grouping(course), grouping(year), course, year"), - Row("Java", 2012, 0) :: - Row("Java", 2013, 0) :: - Row("dotNET", 2012, 0) :: - Row("dotNET", 2013, 0) :: - Row("Java", null, 1) :: - Row("dotNET", null, 1) :: - Row(null, 2012, 2) :: - Row(null, 2013, 2) :: - Row(null, null, 3) :: Nil - ) - - var error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " order by grouping(course)") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by course, year" + - " order by grouping_id(course, year)") - } - assert(error.getMessage contains - "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") - error = intercept[AnalysisException] { - sql("select course, year from courseSales group by cube(course, year)" + - " order by grouping__id") - } - assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") - } - test("filter on a grouping column that is not presented in SELECT") { checkAnswer( sql("select count(1) from (select 1 as a) t group by a having a > 0"), From 
29cea8f332aa3750f8ff7c3b9e705d107278da4b Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 26 Oct 2016 16:12:55 -0700 Subject: [PATCH 030/381] [SPARK-17157][SPARKR] Add multiclass logistic regression SparkR Wrapper ## What changes were proposed in this pull request? As we discussed in #14818, I added a separate R wrapper spark.logit for logistic regression. This single interface supports both binary and multinomial logistic regression. It also has "predict" and "summary" for binary logistic regression. ## How was this patch tested? New unit tests are added. Author: wm624@hotmail.com Closes #15365 from wangmiao1981/glm. --- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 192 +++++++++++++++++- R/pkg/inst/tests/testthat/test_mllib.R | 55 +++++ .../ml/r/LogisticRegressionWrapper.scala | 157 ++++++++++++++ .../org/apache/spark/ml/r/RWrappers.scala | 2 + 6 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index eb314f471893..7a89c01fee73 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -43,7 +43,8 @@ exportMethods("glm", "spark.isoreg", "spark.gaussianMixture", "spark.als", - "spark.kstest") + "spark.kstest", + "spark.logit") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4569fe489046..107e1c638be7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1375,6 +1375,10 @@ setGeneric("spark.gaussianMixture", standardGeneric("spark.gaussianMixture") }) +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) + #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index bf182be8e23d..e441db94998b 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -95,6 +95,13 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) +#' S4 class that represents an LogisticRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel +#' @export +#' @note LogisticRegressionModel since 2.1.0 +setClass("LogisticRegressionModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -105,7 +112,7 @@ setClass("KSTest", representation(jobj = "jobj")) #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} -#' @seealso \link{read.ml} +#' @seealso \link{spark.logit}, \link{read.ml} NULL #' Makes predictions from a MLlib model @@ -117,7 +124,7 @@ NULL #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.logit} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -647,6 +654,170 @@ setMethod("predict", signature(object = "KMeansModel"), predict_internal(object, newData) }) +#' Logistic Regression Model +#' +#' Fits an logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression +#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. +#' Users can print, make predictions on the produced model and save the model to the input path. +#' +#' @param data SparkDataFrame for training +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param regParam the regularization parameter. Default is 0.0. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. +#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination +#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param maxIter maximum iteration number. +#' @param tol convergence tolerance of iterations. +#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param family the name of family which is a description of the label distribution to be used in the model. +#' Supported options: +#' \itemize{ +#' \item{"auto": Automatically select the family based on the number of classes: +#' If number of classes == 1 || number of classes == 2, set to "binomial". +#' Else, set to "multinomial".} +#' \item{"binomial": Binary logistic regression with pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting. +#' Default is "auto".} +#' } +#' @param standardization whether to standardize the training features before fitting the model. The coefficients +#' of models will be always returned on the original scale, so it will be transparent for +#' users. Note that with/without standardization, the models should be always converged +#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 +#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 +#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with +#' threshold p is equivalent to setting thresholds c(1-p, p). When threshold is set, any user-set +#' value for thresholds will be cleared. If both threshold and thresholds are set, then they must be +#' equivalent. 
In multiclass (or binary) classification to adjust the probability of +#' predicting each class. Array must have length equal to the number of classes, with values > 0, +#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p +#' is the original probability of that class and t is the class's threshold. Note: When thresholds +#' is set, any user-set value for threshold will be cleared. If both threshold and thresholds are +#' set, then they must be equivalent. Default is 0.5. +#' @param weightCol The weight column name. +#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions +#' are large, this param could be adjusted to a larger size. Default is 2. +#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' @param ... additional arguments passed to the method. +#' @return \code{spark.logit} returns a fitted logistic regression model +#' @rdname spark.logit +#' @aliases spark.logit,SparkDataFrame,formula-method +#' @name spark.logit +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' # binary logistic regression +#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) +#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) +#' binary_data <- as.data.frame(cbind(label, feature)) +#' binary_df <- createDataFrame(binary_data) +#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) +#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) +#' +#' # summary of binary logistic regression +#' blr_summary <- summary(blr_model) +#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(blr_model, path) +#' +#' # can also read back the saved model and predict +#' Note that summary deos not work on loaded model +#' savedModel <- read.ml(path) +#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' +#' # multinomial logistic regression +#' +#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) +#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) +#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) +#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) +#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) +#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) +#' df <- createDataFrame(data) +#' +#' Note that summary of multinomial logistic regression is not implemented yet +#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds=c(0, 1, 1)) +#' predict1 <- collect(select(predict(model, df), "prediction")) +#' } +#' @note spark.logit since 2.1.0 +setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, + tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, + probabilityCol = "probability") { + formula <- paste0(deparse(formula), collapse = "") + + if (is.null(weightCol)) { + weightCol <- "" + } + + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", + data@sdf, formula, as.numeric(regParam), + as.numeric(elasticNetParam), as.integer(maxIter), + as.numeric(tol), as.logical(fitIntercept), + as.character(family), as.logical(standardization), + 
as.array(thresholds), as.character(weightCol), + as.integer(aggregationDepth), as.character(probabilityCol)) + new("LogisticRegressionModel", jobj = jobj) + }) + +# Predicted values based on an LogisticRegressionModel model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. +#' @rdname spark.logit +#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method +#' @export +#' @note predict(LogisticRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "LogisticRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Get the summary of an LogisticRegressionModel + +#' @param object an LogisticRegressionModel fitted by \code{spark.logit} +#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that +#' Multinomial logistic regression summary is not available now. +#' @rdname spark.logit +#' @aliases summary,LogisticRegressionModel-method +#' @export +#' @note summary(LogisticRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "LogisticRegressionModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + + if (is.loaded) { + stop("Loaded model doesn't have training summary.") + } + + roc <- dataFrame(callJMethod(jobj, "roc")) + + areaUnderROC <- callJMethod(jobj, "areaUnderROC") + + pr <- dataFrame(callJMethod(jobj, "pr")) + + fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) + + precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) + + recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) + + totalIterations <- callJMethod(jobj, "totalIterations") + + objectiveHistory <- callJMethod(jobj, "objectiveHistory") + + list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, + fMeasureByThreshold = fMeasureByThreshold, + precisionByThreshold = precisionByThreshold, + recallByThreshold = recallByThreshold, + totalIterations = totalIterations, objectiveHistory = objectiveHistory) + }) + #' Multilayer Perceptron Classification Model #' #' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. @@ -888,6 +1059,21 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char write_internal(object, path, overwrite) }) +# Save fitted LogisticRegressionModel to the input path + +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.logit +#' @aliases write.ml,LogisticRegressionModel,character-method +#' @export +#' @note write.ml(LogisticRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + # Save fitted MLlib model to the input path #' @param path the directory where the model is saved. 
@@ -938,6 +1124,8 @@ read.ml <- function(path) { new("GaussianMixtureModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { new("ALSModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { + new("LogisticRegressionModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 33cc069f1445..6d1fccc7c058 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -602,6 +602,61 @@ test_that("spark.isotonicRegression", { unlink(modelPath) }) +test_that("spark.logit", { + # test binary logistic regression + label <- c(1.0, 1.0, 1.0, 0.0, 0.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + binary_data <- as.data.frame(cbind(label, feature)) + binary_df <- createDataFrame(binary_data) + + blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) + blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) + expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0)) + blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) + blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) + expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1)) + + # test summary of binary logistic regression + blr_summary <- summary(blr_model) + blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) + expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487), + tolerance = 1e-4) + expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000), + tolerance = 1e-4) + blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) + expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000), + tolerance = 1e-4) + blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) + expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000), + tolerance = 1e-4) + + # test model save and read + modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") + write.ml(blr_model, modelPath) + expect_error(write.ml(blr_model, modelPath)) + write.ml(blr_model, modelPath, overwrite = TRUE) + blr_model2 <- read.ml(modelPath) + blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) + expect_equal(blr_predict$prediction, blr_predict2$prediction) + expect_error(summary(blr_model2)) + unlink(modelPath) + + # test multinomial logistic regression + label <- c(0.0, 1.0, 2.0, 0.0, 0.0) + feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) + feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) + feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) + feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) + data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) + df <- createDataFrame(data) + + model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) + predict1 <- collect(select(predict(model, df), "prediction")) + expect_equal(predict1$prediction, c(0, 0, 0, 0, 0)) + # Summary of multinomial logistic regression is not implemented yet + expect_error(summary(model)) +}) + test_that("spark.gaussianMixture", { # R code to reproduce the result. 
# nolint start diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala new file mode 100644 index 000000000000..9b352c986311 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class LogisticRegressionWrapper private ( + val pipeline: PipelineModel, + val features: Array[String], + val isLoaded: Boolean = false) extends MLWritable { + + private val logisticRegressionModel: LogisticRegressionModel = + pipeline.stages(1).asInstanceOf[LogisticRegressionModel] + + lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations + + lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory + + lazy val blrSummary = + logisticRegressionModel.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + lazy val roc: DataFrame = blrSummary.roc + + lazy val areaUnderROC: Double = blrSummary.areaUnderROC + + lazy val pr: DataFrame = blrSummary.pr + + lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold + + lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold + + lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(logisticRegressionModel.getFeaturesCol) + } + + override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this) +} + +private[r] object LogisticRegressionWrapper + extends MLReadable[LogisticRegressionWrapper] { + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + regParam: Double, + elasticNetParam: Double, + maxIter: Int, + tol: Double, + fitIntercept: Boolean, + family: String, + standardization: Boolean, + thresholds: Array[Double], + weightCol: String, + aggregationDepth: Int, + probability: String + ): LogisticRegressionWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = 
rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val logisticRegression = new LogisticRegression() + .setRegParam(regParam) + .setElasticNetParam(elasticNetParam) + .setMaxIter(maxIter) + .setTol(tol) + .setFitIntercept(fitIntercept) + .setFamily(family) + .setStandardization(standardization) + .setWeightCol(weightCol) + .setAggregationDepth(aggregationDepth) + .setFeaturesCol(rFormula.getFeaturesCol) + .setProbabilityCol(probability) + + if (thresholds.length > 1) { + logisticRegression.setThresholds(thresholds) + } else { + logisticRegression.setThreshold(thresholds(0)) + } + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, logisticRegression)) + .fit(data) + + new LogisticRegressionWrapper(pipeline, features) + } + + override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader + + override def load(path: String): LogisticRegressionWrapper = super.load(path) + + class LogisticRegressionWrapperWriter(instance: LogisticRegressionWrapper) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.pipeline.save(pipelinePath) + } + } + + class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] { + + override def load(path: String): LogisticRegressionWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val features = (rMetadata \ "features").extract[Array[String]] + + val pipeline = PipelineModel.load(pipelinePath) + new LogisticRegressionWrapper(pipeline, features, isLoaded = true) + } + } +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index d64de1b6abb6..1df3662a5822 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -54,6 +54,8 @@ private[r] object RWrappers extends MLReader[Object] { GaussianMixtureWrapper.load(path) case "org.apache.spark.ml.r.ALSWrapper" => ALSWrapper.load(path) + case "org.apache.spark.ml.r.LogisticRegressionWrapper" => + LogisticRegressionWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } From a76846cfb1c2d6c8f4d647426030b59de20d9433 Mon Sep 17 00:00:00 2001 From: Miao Wang Date: Thu, 27 Oct 2016 01:17:32 +0200 Subject: [PATCH 031/381] [SPARK-18126][SPARK-CORE] getIteratorZipWithIndex accepts negative value as index ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) `Utils.getIteratorZipWithIndex` was added to deal with number of records > 2147483647 in one partition. method `getIteratorZipWithIndex` accepts `startIndex` < 0, which leads to negative index. This PR just adds a defensive check on `startIndex` to make sure it is >= 0. 
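To make the guard concrete, here is a small self-contained sketch in the same spirit (the helper name `zipWithIndexFrom` is invented for illustration; the actual change is only the single `require` line added to `Utils.getIteratorZipWithIndex`, shown in the diff below):

```
// A guarded zip-with-index: fail fast on a negative start index instead of
// silently emitting negative indices.
def zipWithIndexFrom[T](iter: Iterator[T], startIndex: Long): Iterator[(T, Long)] = {
  require(startIndex >= 0, "startIndex should be >= 0.")
  new Iterator[(T, Long)] {
    private var index: Long = startIndex - 1L
    def hasNext: Boolean = iter.hasNext
    def next(): (T, Long) = {
      index += 1
      (iter.next(), index)
    }
  }
}

// zipWithIndexFrom(Iterator("a", "b"), 5L).toList == List(("a", 5L), ("b", 6L))
// zipWithIndexFrom(Iterator("a"), -1L) now throws IllegalArgumentException up front.
```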
## How was this patch tested?

Add a new unit test.

Author: Miao Wang

Closes #15639 from wangmiao1981/zip.
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 1 + core/src/test/scala/org/apache/spark/util/UtilsSuite.scala | 3 +++ 2 files changed, 4 insertions(+)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e57eb0de2689..6027b07c0fee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1765,6 +1765,7 @@ private[spark] object Utils extends Logging { */ def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = { new Iterator[(T, Long)] { + require(startIndex >= 0, "startIndex should be >= 0.") var index: Long = startIndex - 1L def hasNext: Boolean = iterator.hasNext def next(): (T, Long) = {
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index aeb2969fd579..15ef32f21d90 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -401,6 +401,9 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(iterator.toArray === Array( (0, -1L + Int.MaxValue), (1, 0L + Int.MaxValue), (2, 1L + Int.MaxValue) )) + intercept[IllegalArgumentException] { + Utils.getIteratorZipWithIndex(Iterator(0, 1, 2), -1L) + } } test("doesDirectoryContainFilesNewerThan") {

From 5b27598ff50cb08e7570fade458da0a3d4d4eabc Mon Sep 17 00:00:00 2001
From: frreiss
Date: Wed, 26 Oct 2016 17:33:08 -0700
Subject: [PATCH 032/381] [SPARK-16963][STREAMING][SQL] Changes to Source trait and related implementation classes

## What changes were proposed in this pull request?

This PR contains changes to the Source trait such that the scheduler can notify data sources when it is safe to discard buffered data. Summary of changes:

* Added a method `commit(end: Offset)` that tells the Source that it is OK to discard all offsets up to `end`, inclusive.
* Changed the semantics of a `None` value for the `getBatch` method to mean "from the very beginning of the stream", as opposed to "all data present in the Source's buffer".
* Added notes that the upper layers of the system will never call `getBatch` with a start value less than the last value passed to `commit`.
* Added a `lastCommittedOffset` method to allow the scheduler to query the status of each Source on restart. This addition is not strictly necessary, but it seemed like a good idea -- Sources will be maintaining their own persistent state, and there may be bugs in the checkpointing code.
* The scheduler in `StreamExecution.scala` now calls `commit` on its stream sources after marking each batch as complete in its checkpoint.
* `MemoryStream` now cleans committed batches out of its internal buffer.
* `TextSocketSource` now cleans committed batches from its internal buffer.

## How was this patch tested?

Existing regression tests already exercise the new code.

Author: frreiss

Closes #14553 from frreiss/fred-16963.
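As an illustrative aside, a minimal sketch of the calling convention described above; `processOneBatch`, `prevBatchOffset`, `persistToOffsetLog`, and `runBatch` are invented stand-ins, not Spark APIs, and the real logic lives in `StreamExecution`:

```
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.streaming.{Offset, Source}

// One micro-batch under the revised Source contract (simplified; error handling omitted).
def processOneBatch(
    source: Source,
    prevBatchOffset: Option[Offset],      // offset of the last fully checkpointed batch, if any
    persistToOffsetLog: Offset => Unit,   // stand-in for StreamExecution's offset write-ahead log
    runBatch: DataFrame => Unit): Option[Offset] = {
  source.getOffset.map { end =>
    // getBatch is never called with a start lower than the last offset passed to commit();
    // a start of None means "from the very beginning of the stream".
    val batchDf = source.getBatch(prevBatchOffset, end)
    persistToOffsetLog(end)                 // checkpoint the new batch's offsets first
    prevBatchOffset.foreach(source.commit)  // then let the source discard the previous batch
    runBatch(batchDf)
    end
  }
}
```

The ordering is the key design point: the offset log is written before `commit`, so a failure between the two steps can only leave a source holding data longer than necessary, never discard data that is still needed.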
--- .../streaming/FileStreamSource.scala | 9 +++ .../sql/execution/streaming/Source.scala | 22 ++++-- .../execution/streaming/StreamExecution.scala | 32 ++++++--- .../sql/execution/streaming/memory.scala | 47 ++++++++++-- .../sql/execution/streaming/socket.scala | 72 +++++++++++++++---- .../sql/streaming/StreamingQuerySuite.scala | 8 +-- 6 files changed, 154 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 115edf7ab2b6..a392b8299902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -176,6 +176,15 @@ class FileStreamSource( override def toString: String = s"FileStreamSource[$qualifiedBasePath]" + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + override def commit(end: Offset): Unit = { + // No-op for now; FileStreamSource currently garbage-collects files based on timestamp + // and the value of the maxFileAge parameter. + } + override def stop() {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 971147840d2f..f3bd5bfe23fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -30,16 +30,30 @@ trait Source { /** Returns the schema of the data from this source */ def schema: StructType - /** Returns the maximum available offset for this source. */ + /** + * Returns the maximum available offset for this source. + * Returns `None` if this source has never received any data. + */ def getOffset: Option[Offset] /** - * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then - * the batch should begin with the first available record. This method must always return the - * same data for a particular `start` and `end` pair. + * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None`, + * then the batch should begin with the first record. This method must always return the + * same data for a particular `start` and `end` pair; even after the Source has been restarted + * on a different node. + * + * Higher layers will always call this method with a value of `start` greater than or equal + * to the last value passed to `commit` and a value of `end` less than or equal to the + * last value returned by `getOffset` */ def getBatch(start: Option[Offset], end: Offset): DataFrame + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + def commit(end: Offset) : Unit = {} + /** Stop this source and free any resources it has allocated. 
*/ def stop(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ba8cf808e339..37af1a550aaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -73,6 +73,9 @@ class StreamExecution( /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. + * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile var committedOffsets = new StreamProgress @@ -80,6 +83,9 @@ class StreamExecution( /** * Tracks the offsets that are available to be processed, but have not yet be committed to the * sink. + * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile private var availableOffsets = new StreamProgress @@ -337,17 +343,27 @@ class StreamExecution( } if (hasNewData) { reportTimeTaken(OFFSET_WAL_WRITE_LATENCY) { - assert( - offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId.") + // NOTE: The following code is correct because runBatches() processes exactly one + // batch at a time. If we add pipeline parallelism (multiple batches in flight at + // the same time), this cleanup logic will need to change. + + // Now that we've updated the scheduler's persistent checkpoint, it is safe for the + // sources to discard data from the previous batch. + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } + // Now that we have logged the new batch, no further processing will happen for - // the previous batch, and it is safe to discard the old metadata. - // Note that purge is exclusive, i.e. it purges everything before currentBatchId. - // NOTE: If StreamExecution implements pipeline parallelism (multiple batches in - // flight at the same time), this cleanup logic will need to change. - offsetLog.purge(currentBatchId) + // the batch before the previous batch, and it is safe to discard the old metadata. + // Note that purge is exclusive, i.e. it purges everything before the target ID. + offsetLog.purge(currentBatchId - 1) } } else { awaitBatchLock.lock() @@ -455,7 +471,7 @@ class StreamExecution( /** * Blocks the current thread until processing for data from the given `source` has reached at - * least the given `Offset`. This method is indented for use primarily when writing tests. + * least the given `Offset`. This method is intended for use primarily when writing tests. 
*/ private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { def notDone = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 788fcd0361be..48d9791faf1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging @@ -51,12 +51,23 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. + */ @GuardedBy("this") - protected val batches = new ArrayBuffer[Dataset[A]] + protected val batches = new ListBuffer[Dataset[A]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + /** + * Last offset that was discarded, or -1 if no commits have occurred. Note that the value + * -1 is used in calculations below and isn't just an arbitrary constant. + */ + @GuardedBy("this") + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { @@ -85,21 +96,25 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" override def getOffset: Option[Offset] = synchronized { - if (batches.isEmpty) { + if (currentOffset.offset == -1) { None } else { Some(currentOffset) } } - /** - * Returns the data that is between the offsets (`start`, `end`]. - */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } + + // Internal buffer only holds the batches after lastCommittedOffset. 
+ val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -111,11 +126,29 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } + override def commit(end: Offset): Unit = synchronized { + end match { + case newOffset: LongOffset => + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset + case _ => + sys.error(s"MemoryStream.commit() received an offset ($end) that did not originate with " + + "an instance of this class") + } + } + override def stop() {} def reset(): Unit = synchronized { batches.clear() currentOffset = new LongOffset(-1) + lastOffsetCommitted = new LongOffset(-1) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index fb15239f9af9..c662e7c6bc77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -24,14 +24,15 @@ import java.text.SimpleDateFormat import java.util.Calendar import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + object TextSocketSource { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -53,8 +54,18 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo @GuardedBy("this") private var readThread: Thread = null + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. + */ + @GuardedBy("this") + protected val batches = new ListBuffer[(String, Timestamp)] + + @GuardedBy("this") + protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") - private var lines = new ArrayBuffer[(String, Timestamp)] + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) initialize() @@ -74,10 +85,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo return } TextSocketSource.this.synchronized { - lines += ((line, + val newData = (line, Timestamp.valueOf( TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - )) + ) + currentOffset = currentOffset + 1 + batches.append(newData) } } } catch { @@ -92,21 +105,54 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else TextSocketSource.SCHEMA_REGULAR - /** Returns the maximum available offset for this source. 
*/ override def getOffset: Option[Offset] = synchronized { - if (lines.isEmpty) None else Some(LongOffset(lines.size - 1)) + if (currentOffset.offset == -1) { + None + } else { + Some(currentOffset) + } } /** Returns the data that is between the offsets (`start`, `end`]. */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startIdx = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0) - val endIdx = end.asInstanceOf[LongOffset].offset.toInt + 1 - val data = synchronized { lines.slice(startIdx, endIdx) } + val startOrdinal = + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + + // Internal buffer only holds the batches after lastOffsetCommitted + val rawList = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } + import sqlContext.implicits._ + val rawBatch = sqlContext.createDataset(rawList) + + // Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp + // if requested. if (includeTimestamp) { - data.toDF("value", "timestamp") + rawBatch.toDF("value", "timestamp") + } else { + // Strip out timestamp + rawBatch.select("_1").toDF("value") + } + } + + override def commit(end: Offset): Unit = synchronized { + if (end.isInstanceOf[LongOffset]) { + val newOffset = end.asInstanceOf[LongOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset } else { - data.map(_._1).toDF("value") + sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " + + s"originate with an instance of this class") } } @@ -141,7 +187,7 @@ class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegis providerName: String, parameters: Map[String, String]): (String, StructType) = { logWarning("The socket source should not be used for production applications! " + - "It does not support recovery and stores state indefinitely.") + "It does not support recovery.") if (!parameters.contains("host")) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 92020be9789f..dad410486ed2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -252,8 +252,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) - // Run 3 batches, and then assert that only 1 metadata file is left at the end - // since the first 2 should have been purged. + // Run 3 batches, and then assert that only 2 metadata files is are at the end + // since the first should have been purged. 
testStream(mapped)( AddData(inputData, 1, 2), CheckAnswer(6, 3), @@ -262,11 +262,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AddData(inputData, 4, 6), CheckAnswer(6, 3, 6, 3, 1, 1), - AssertOnQuery("metadata log should contain only one file") { q => + AssertOnQuery("metadata log should contain only two files") { q => val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) val toTest = logFileNames.filter(! _.endsWith(".crc")) // Workaround for SPARK-17475 - assert(toTest.size == 1 && toTest.head == "2") + assert(toTest.size == 2 && toTest.head == "1") true } ) From f1aeed8b022e043de2eb38b30187dcc36ee8dcdb Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Wed, 26 Oct 2016 18:03:31 -0700 Subject: [PATCH 033/381] [SPARK-17770][CATALYST] making ObjectType public ## What changes were proposed in this pull request? In order to facilitate the writing of additional Encoders, I proposed opening up the ObjectType SQL DataType. This DataType is used extensively in the JavaBean Encoder, but would also be useful in writing other custom encoders. As mentioned by marmbrus, it is understood that the Expressions API is subject to potential change. ## How was this patch tested? The change only affects the visibility of the ObjectType class, and the existing SQL test suite still runs without error. Author: ALeksander Eskilson Closes #15453 from bdrillard/master. --- .../org/apache/spark/sql/types/ObjectType.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index c741a2dd3ea3..b18fba29af0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.types import scala.language.existentials -private[sql] object ObjectType extends AbstractDataType { +import org.apache.spark.annotation.InterfaceStability + +@InterfaceStability.Evolving +object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException("null literals can't be casted to ObjectType") @@ -32,11 +35,10 @@ private[sql] object ObjectType extends AbstractDataType { } /** - * Represents a JVM object that is passing through Spark SQL expression evaluation. Note this - * is only used internally while converting into the internal format and is not intended for use - * outside of the execution engine. + * Represents a JVM object that is passing through Spark SQL expression evaluation. */ -private[sql] case class ObjectType(cls: Class[_]) extends DataType { +@InterfaceStability.Evolving +case class ObjectType(cls: Class[_]) extends DataType { override def defaultSize: Int = 4096 def asNullable: DataType = this From dd4f088c1df6abd728e5544a17ba85322bedfe4c Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 27 Oct 2016 13:12:14 +0800 Subject: [PATCH 034/381] [SPARK-18009][SQL] Fix ClassCastException while calling toLocalIterator() on dataframe produced by RunnableCommand ## What changes were proposed in this pull request? A short code snippet that uses toLocalIterator() on a dataframe produced by a RunnableCommand reproduces the problem. 
toLocalIterator() is called by thriftserver when `spark.sql.thriftServer.incrementalCollect`is set to handle queries producing large result set. **Before** ```SQL scala> spark.sql("show databases") res0: org.apache.spark.sql.DataFrame = [databaseName: string] scala> res0.toLocalIterator() 16/10/26 03:00:24 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericInternalRow cannot be cast to org.apache.spark.sql.catalyst.expressions.UnsafeRow ``` **After** ```SQL scala> spark.sql("drop database databases") res30: org.apache.spark.sql.DataFrame = [] scala> spark.sql("show databases") res31: org.apache.spark.sql.DataFrame = [databaseName: string] scala> res31.toLocalIterator().asScala foreach println [default] [parquet] ``` ## How was this patch tested? Added a test in DDLSuite Author: Dilip Biswal Closes #15642 from dilipbiswal/SPARK-18009. --- .../org/apache/spark/sql/execution/command/commands.scala | 2 ++ .../org/apache/spark/sql/execution/command/DDLSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 698c625d617f..d82e54e57564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -66,6 +66,8 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray + override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index de326f80f659..b989d01ec787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1805,4 +1805,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } } + + test("SPARK-18009 calling toLocalIterator on commands") { + import scala.collection.JavaConverters._ + val df = sql("show databases") + val rows: Seq[Row] = df.toLocalIterator().asScala.toSeq + assert(rows.length > 0) + } } From d3b4831d009905185ad74096ce3ecfa934bc191d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 26 Oct 2016 22:22:23 -0700 Subject: [PATCH 035/381] [SPARK-18132] Fix checkstyle This PR fixes checkstyle. Author: Yin Huai Closes #15656 from yhuai/fix-format. 
--- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 783501791023..dcae4a34c4b0 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -143,9 +143,10 @@ private UnsafeExternalSorter( this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, and then discarded (this fixes SPARK-16827). + // The spill metrics are stored in a new ShuffleWriteMetrics, + // and then discarded (this fixes SPARK-16827). // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). this.writeMetrics = new ShuffleWriteMetrics(); From 1dbe9896b7f30538a5fad2f5d718d035c7906936 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 26 Oct 2016 23:02:54 -0700 Subject: [PATCH 036/381] [SPARK-17157][SPARKR][FOLLOW-UP] doc fixes ## What changes were proposed in this pull request? a couple of small late finding fixes for doc ## How was this patch tested? manually wangmiao1981 Author: Felix Cheung Closes #15650 from felixcheung/logitfix. --- R/pkg/R/mllib.R | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index e441db94998b..629f284b79f3 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -111,8 +111,9 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} -#' @seealso \link{spark.logit}, \link{read.ml} +#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.survreg} +#' @seealso \link{read.ml} NULL #' Makes predictions from a MLlib model @@ -124,7 +125,7 @@ NULL #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.logit} +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -671,14 +672,13 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param tol convergence tolerance of iterations. #' @param fitIntercept whether to fit an intercept term. Default is TRUE. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: +#' Supported options: Default is "auto". 
#' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". #' Else, set to "multinomial".} #' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting. -#' Default is "auto".} +#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} #' } #' @param standardization whether to standardize the training features before fitting the model. The coefficients #' of models will be always returned on the original scale, so it will be transparent for @@ -687,14 +687,10 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 #' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 #' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). When threshold is set, any user-set -#' value for thresholds will be cleared. If both threshold and thresholds are set, then they must be -#' equivalent. In multiclass (or binary) classification to adjust the probability of +#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Note: When thresholds -#' is set, any user-set value for threshold will be cleared. If both threshold and thresholds are -#' set, then they must be equivalent. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. Default is 0.5. #' @param weightCol The weight column name. #' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions #' are large, this param could be adjusted to a larger size. Default is 2. @@ -724,7 +720,7 @@ setMethod("predict", signature(object = "KMeansModel"), #' write.ml(blr_model, path) #' #' # can also read back the saved model and predict -#' Note that summary deos not work on loaded model +#' # Note that summary deos not work on loaded model #' savedModel <- read.ml(path) #' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) #' @@ -738,8 +734,8 @@ setMethod("predict", signature(object = "KMeansModel"), #' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) #' df <- createDataFrame(data) #' -#' Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds=c(0, 1, 1)) +#' # Note that summary of multinomial logistic regression is not implemented yet +#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) #' predict1 <- collect(select(predict(model, df), "prediction")) #' } #' @note spark.logit since 2.1.0 From 44c8bfda793b7655e2bd1da5e9915a09ed9d42ce Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 26 Oct 2016 23:06:11 -0700 Subject: [PATCH 037/381] [SQL][DOC] updating doc for JSON source to link to jsonlines.org ## What changes were proposed in this pull request? 
API and programming guide doc changes for Scala, Python and R. ## How was this patch tested? manual test Author: Felix Cheung Closes #15629 from felixcheung/jsondoc. --- R/pkg/R/DataFrame.R | 3 ++- R/pkg/R/SQLContext.R | 3 ++- docs/sparkr.md | 2 +- docs/sql-programming-guide.md | 22 +++++++++++-------- python/pyspark/sql/readwriter.py | 5 +++-- python/pyspark/sql/streaming.py | 3 ++- .../apache/spark/sql/DataFrameReader.scala | 14 +++++++----- .../apache/spark/sql/DataFrameWriter.scala | 3 ++- .../sql/streaming/DataStreamReader.scala | 3 ++- 9 files changed, 35 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index be34e4b32f6f..1df8bbf9fe60 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -761,7 +761,8 @@ setMethod("toJSON", #' Save the contents of SparkDataFrame as a JSON file #' -#' Save the contents of a SparkDataFrame as a JSON file (one object per line). Files written out +#' Save the contents of a SparkDataFrame as a JSON file (\href{http://jsonlines.org/}{ +#' JSON Lines text format or newline-delimited JSON}). Files written out #' with this method can be read back in as a SparkDataFrame using read.json(). #' #' @param x A SparkDataFrame diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 0d6a229e6345..216ca51666ba 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -324,7 +324,8 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a SparkDataFrame +#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ), returning the result as a SparkDataFrame #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. diff --git a/docs/sparkr.md b/docs/sparkr.md index c1829efd18f4..f30bd4026fed 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -135,7 +135,7 @@ sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") {% endhighlight %} -We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a consequence, a regular multi-line JSON file will most often fail.
{% highlight r %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 064af41965b7..b9be7a7545ef 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -316,7 +316,7 @@ Serializable and has getters and setters for all of its fields. Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by sampling the whole datase, similar to the inference that is performed on JSON files. +and the types are inferred by sampling the whole dataset, similar to the inference that is performed on JSON files. {% include_example schema_inferring python/sql/basic.py %}
@@ -832,8 +832,9 @@ This conversion can be done using `SparkSession.read.json()` on either an RDD of or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -844,8 +845,9 @@ This conversion can be done using `SparkSession.read().json()` on either an RDD or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} @@ -855,8 +857,9 @@ Spark SQL can automatically infer the schema of a JSON dataset and load it as a This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset python/sql/datasource.py %} @@ -867,8 +870,9 @@ the `read.json()` function, which loads data from a directory of JSON files wher files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each -line must contain a separate, self-contained valid JSON object. As a consequence, -a regular multi-line JSON file will most often fail. +line must contain a separate, self-contained valid JSON object. For more information, please see +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a +consequence, a regular multi-line JSON file will most often fail. {% include_example json_dataset r/RSparkSQLExample.R %} diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 91c2b17049fa..bc786ef95ed0 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -160,8 +160,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ - Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects - (one object per record) and returns the result as a :class`DataFrame`. 
+ Loads a JSON file (`JSON Lines text format or newline-delimited JSON + <[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per + record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 35fc46929168..559647bbabf6 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -640,7 +640,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ - Loads a JSON file stream (one object per line) and returns a :class`DataFrame`. + Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON + <[http://jsonlines.org/>`_) and returns a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b7b2203cdd85..a77937efd7e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -239,7 +239,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file ([[http://jsonlines.org/ JSON Lines text format or newline-delimited JSON]]) + * and returns the result as a [[DataFrame]]. * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 @@ -250,7 +251,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file ([[http://jsonlines.org/ JSON Lines text format or newline-delimited JSON]]) + * and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -295,8 +297,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def json(paths: String*): DataFrame = format("json").load(paths : _*) /** - * Loads a `JavaRDD[String]` storing JSON objects (one object per record) and - * returns the result as a [[DataFrame]]. + * Loads a `JavaRDD[String]` storing JSON objects ([[http://jsonlines.org/ JSON Lines text format + * or newline-delimited JSON]]) and returns the result as a [[DataFrame]]. * * Unless the schema is specified using [[schema]] function, this function goes through the * input once to determine the input schema. @@ -307,8 +309,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) /** - * Loads an `RDD[String]` storing JSON objects (one object per record) and - * returns the result as a [[DataFrame]]. + * Loads an `RDD[String]` storing JSON objects ([[http://jsonlines.org/ JSON Lines text format or + * newline-delimited JSON]]) and returns the result as a [[DataFrame]]. 
* * Unless the schema is specified using [[schema]] function, this function goes through the * input once to determine the input schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5be3277651d0..4b5f0246b9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -434,7 +434,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } /** - * Saves the content of the [[DataFrame]] in JSON format at the specified path. + * Saves the content of the [[DataFrame]] in JSON format ([[http://jsonlines.org/ JSON Lines text + * format or newline-delimited JSON]]) at the specified path. * This is equivalent to: * {{{ * format("json").save(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 87b73062180e..40b482e4c01a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -134,7 +134,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. + * Loads a JSON file stream ([[http://jsonlines.org/ JSON Lines text format or newline-delimited + * JSON]]) and returns the result as a [[DataFrame]]. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. From 701a9d361b3045a25c42b3c0e44e7755d45ff78c Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 27 Oct 2016 10:00:37 +0200 Subject: [PATCH 038/381] [SPARK-CORE][TEST][MINOR] Fix the wrong comment in test ## What changes were proposed in this pull request? While learning core scheduler code, I found two lines of wrong comments. This PR simply corrects the comments. ## How was this patch tested? Author: wm624@hotmail.com Closes #15631 from wangmiao1981/Rbug. 
--- .../org/apache/spark/scheduler/TaskSetManagerSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index b49ba085ca5d..1b1a764ceff9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -261,14 +261,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) clock.advance(LOCALITY_WAIT_MS) - // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should + // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 3) should // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) - // Offer host2, exec3 again, at NODE_LOCAL level: we should choose task 2 + // Offer host2, exec2, at NODE_LOCAL level: we should choose task 2 assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index == 1) - // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task + // Offer host2, exec2 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) clock.advance(LOCALITY_WAIT_MS) From 104232580528c097a284d753adb5795f6de8b0a5 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 27 Oct 2016 10:30:59 -0700 Subject: [PATCH 039/381] [SPARK-17813][SQL][KAFKA] Maximum data per trigger ## What changes were proposed in this pull request? maxOffsetsPerTrigger option for rate limiting, proportionally based on volume of different topicpartitions. ## How was this patch tested? Added unit test Author: cody koeninger Closes #15527 from koeninger/SPARK-17813. --- .../structured-streaming-kafka-integration.md | 6 + .../spark/sql/kafka010/KafkaSource.scala | 107 ++++++++++++++---- .../spark/sql/kafka010/KafkaSourceSuite.scala | 71 +++++++++++- 3 files changed, 157 insertions(+), 27 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index e851f210c92c..a6c3b3a9024d 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -221,6 +221,12 @@ The following configurations are optional: 10 milliseconds to wait before retrying to fetch Kafka offsets + + maxOffsetsPerTrigger + long + none + Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume. 
+ Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 537b7b0baa1b..61cba737d148 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -96,6 +96,9 @@ private[kafka010] case class KafkaSource( private val offsetFetchAttemptIntervalMs = sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + private val maxOffsetsPerTrigger = + sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong) + /** * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. @@ -121,6 +124,8 @@ private[kafka010] case class KafkaSource( }.partitionToOffsets } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None + override def schema: StructType = KafkaSource.kafkaSchema /** Returns the maximum available offset for this source. */ @@ -128,9 +133,54 @@ private[kafka010] case class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets - val offset = KafkaSourceOffset(fetchLatestOffsets()) - logDebug(s"GetOffset: ${offset.partitionToOffsets.toSeq.map(_.toString).sorted}") - Some(offset) + val latest = fetchLatestOffsets() + val offsets = maxOffsetsPerTrigger match { + case None => + latest + case Some(limit) if currentPartitionOffsets.isEmpty => + rateLimit(limit, initialPartitionOffsets, latest) + case Some(limit) => + rateLimit(limit, currentPartitionOffsets.get, latest) + } + + currentPartitionOffsets = Some(offsets) + logDebug(s"GetOffset: ${offsets.toSeq.map(_.toString).sorted}") + Some(KafkaSourceOffset(offsets)) + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: Map[TopicPartition, Long], + until: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + val fromNew = fetchNewPartitionEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + logDebug(s"rateLimit $tp prorated amount is $prorate") + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + logDebug(s"rateLimit $tp new offset is $off") + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } } /** @@ -153,11 +203,7 @@ private[kafka010] case class KafkaSource( // Find the new partitions, and get their earliest offsets val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet) - val newPartitionOffsets = if (newPartitions.nonEmpty) { - fetchNewPartitionEarliestOffsets(newPartitions.toSeq) - } else { - Map.empty[TopicPartition, Long] - } + val 
newPartitionOffsets = fetchNewPartitionEarliestOffsets(newPartitions.toSeq) if (newPartitionOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) @@ -221,6 +267,12 @@ private[kafka010] case class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) + + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + sqlContext.createDataFrame(rdd, schema) } @@ -305,23 +357,28 @@ private[kafka010] case class KafkaSource( * some partitions if they are deleted. */ private def fetchNewPartitionEarliestOffsets( - newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { - // Poll to get the latest assigned partitions - consumer.poll(0) - val partitions = consumer.assignment() - consumer.pause(partitions) - logDebug(s"\tPartitions assigned to consumer: $partitions") - - // Get the earliest offset of each partition - consumer.seekToBeginning(partitions) - val partitionOffsets = newPartitions.filter { p => - // When deleting topics happen at the same time, some partitions may not be in `partitions`. - // So we need to ignore them - partitions.contains(p) - }.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") - partitionOffsets - } + newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = + if (newPartitions.isEmpty) { + Map.empty[TopicPartition, Long] + } else { + withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"\tPartitions assigned to consumer: $partitions") + + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in + // `partitions`. So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") + partitionOffsets + } + } /** * Helper function that does multiple retries on the a body of code that returns offsets. 
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index b50688ecb774..ed4cc75920e8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -23,13 +23,14 @@ import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest } import org.apache.spark.sql.test.SharedSQLContext - abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected var testUtils: KafkaTestUtils = _ @@ -133,6 +134,72 @@ class KafkaSourceSuite extends KafkaSourceTest { private val topicId = new AtomicInteger(0) + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(newTopic(), partitions = 5) From 
0b076d4cb6afde2946124e6411ed6a6ce7b8b1a7 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Thu, 27 Oct 2016 11:52:15 -0700 Subject: [PATCH 040/381] [SPARK-17219][ML] enhanced NaN value handling in Bucketizer ## What changes were proposed in this pull request? This PR is an enhancement of PR with commit ID:57dc326bd00cf0a49da971e9c573c48ae28acaa2. NaN is a special type of value which is commonly seen as invalid. But We find that there are certain cases where NaN are also valuable, thus need special handling. We provided user when dealing NaN values with 3 options, to either reserve an extra bucket for NaN values, or remove the NaN values, or report an error, by setting handleNaN "keep", "skip", or "error"(default) respectively. '''Before: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) '''After: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) .setHandleNaN("keep") ## How was this patch tested? Tests added in QuantileDiscretizerSuite, BucketizerSuite and DataFrameStatSuite Signed-off-by: VinceShieh Author: VinceShieh Author: Vincent Xie Author: Joseph K. Bradley Closes #15428 from VinceShieh/spark-17219_followup. --- docs/ml-features.md | 15 ++-- .../apache/spark/ml/feature/Bucketizer.scala | 71 +++++++++++++++++-- .../ml/feature/QuantileDiscretizer.scala | 47 ++++++++++-- .../spark/ml/feature/BucketizerSuite.scala | 26 +++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 35 ++++++--- python/pyspark/ml/feature.py | 5 -- .../apache/spark/sql/DataFrameStatSuite.scala | 4 ++ 7 files changed, 161 insertions(+), 42 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a7f710fa52e6..64c6a160239c 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1103,11 +1103,16 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible -that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that NaN values are -handled specially and placed into their own bucket. For example, if 4 buckets are used, then -non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. -The bin ranges are chosen using an approximate algorithm (see the documentation for +that the number of buckets used will be smaller than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. + +NaN values: Note also that QuantileDiscretizer +will raise an error when it finds NaN values in the dataset, but the user can also choose to either +keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep +NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets +are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. + +Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the `relativeError` parameter. 
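For reference, a short, hypothetical usage sketch of the NaN handling described in the doc change above (the DataFrame `df` and the column names are placeholders, not part of this patch):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// Assumes an existing DataFrame `df` with a numeric column "hour" that may contain NaN.
val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setHandleInvalid("keep")  // "skip" drops NaN rows; the default "error" throws on NaN

// With "keep", NaN values go into their own extra bucket after the regular ones.
val bucketed = discretizer.fit(df).transform(df)
```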
When set to zero, exact quantiles are calculated diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ec0ea05f9e1b..1143f0f565eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * also includes y. Splits should be of length >= 3 and strictly increasing. * Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * @group param */ @Since("1.4.0") @@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val bucketizer = udf { feature: Double => - Bucketizer.binarySearchForBuckets($(splits), feature) + val (filteredDataset, keepInvalid) = { + if (getHandleInvalid == Bucketizer.SKIP_INVALID) { + // "skip" NaN option is set, will filter out NaN values in the dataset + (dataset.na.drop().toDF(), false) + } else { + (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) + } + } + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(dataset($(inputCol))) - val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol, newField.metadata) + + val newCol = bucketizer(filteredDataset($(inputCol))) + val newField = prepOutputField(filteredDataset.schema) + filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val 
ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalid: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + /** * We require splits to be of length >= 3 and to be in strictly increasing order. * No NaN split should be accepted. @@ -126,11 +168,26 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** * Binary searching in several buckets to place each data point. + * @param splits array of split points + * @param feature data point + * @param keepInvalid NaN flag. + * Set "true" to make an extra bucket for NaN values; + * Set "false" to report an error for NaN values + * @return bucket for each data point * @throws SparkException if a feature is < splits.head or > splits.last */ - private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + + private[feature] def binarySearchForBuckets( + splits: Array[Double], + feature: Double, + keepInvalid: Boolean): Double = { if (feature.isNaN) { - splits.length - 1 + if (keepInvalid) { + splits.length - 1 + } else { + throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," + + " try setting Bucketizer.handleInvalid.") + } } else if (feature == splits.last) { splits.length - 2 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 05e034d90f6a..b9e01dde70d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * default: 2 * @group param */ @@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getRelativeError: Double = getOrDefault(relativeError) + + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + } /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there - * are too few distinct values of the input to create enough distinct quantiles. Note also that - * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets - * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special - * bucket(4). 
- * The bin ranges are chosen using an approximate algorithm (see the documentation for + * possible that the number of buckets used will be smaller than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. + * + * NaN handling: Note also that + * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can + * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. + * If the user chooses to keep NaN values, they will be handled specially and placed into their own + * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], + * but NaNs will be counted in a special bucket[4]. + * + * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, @@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkNumericType(schema, $(inputCol)) @@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) + val bucketizer = new Bucketizer(uid) + .setSplits(distinctSplits.sorted) + .setHandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 87cdceb26738..aac29137d791 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -99,21 +99,32 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) + bucketizer.setHandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. 
Expected $y but found $x") } + + bucketizer.setHandleInvalid("skip") + val skipResults: Array[Double] = bucketizer.transform(dataFrame) + .select("result").as[Double].collect() + assert(skipResults.length === 7) + assert(skipResults.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } } test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) - withClue("Invalid NaN split was not caught as an invalid split!") { + withClue("Invalid NaN split was not caught during Bucketizer initialization") { intercept[IllegalArgumentException] { - val bucketizer: Bucketizer = new Bucketizer() - .setInputCol("feature") - .setOutputCol("result") - .setSplits(splits) + new Bucketizer().setSplits(splits) } } } @@ -138,7 +149,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val data = Array.fill(100)(Random.nextDouble()) val splits: Array[Double] = Double.NegativeInfinity +: Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity - val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val bsResult = Vectors.dense(data.map(x => + Bucketizer.binarySearchForBuckets(splits, x, false))) val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } @@ -169,7 +181,7 @@ private object BucketizerSuite extends SparkFunSuite { /** Check all values in splits, plus values between all splits. */ def checkBinarySearch(splits: Array[Double]): Unit = { def testFeature(feature: Double, expectedBucket: Double): Unit = { - assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket, s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + s" ${splits.mkString(", ")}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6822594044a5..f219f775b218 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite @@ -76,20 +76,33 @@ class QuantileDiscretizerSuite import spark.implicits._ val numBuckets = 3 - val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) - .map(Tuple1.apply).toDF("input") + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - // Reserve extra one bucket for NaN - val expectedNumBuckets = 
discretizer.fit(df).getSplits.length - 1 - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData.toSeq.toDF("input") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ + case(u, v) => + discretizer.setHandleInvalid(u) + val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } } test("Test transform method on unseen data") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 7683360664eb..94afe82a3647 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,11 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. - It is possible that the number of buckets used will be less than this value, for example, if - there are too few distinct values of the input to create enough distinct quantiles. Note also - that NaN values are handled specially and placed into their own bucket. For example, if 4 - buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in - a special bucket(4). The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 73026c749db4..1383208874a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) } + // test approxQuantile on NaN values + val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + assert(resNaN.count(_.isNaN) === 0) } test("crosstab") { From 79fd0cc0584e48fb021c4237877b15abbffb319a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 27 Oct 2016 12:32:58 -0700 Subject: [PATCH 041/381] [SPARK-16963][SQL] Fix test "StreamExecution metadata garbage collection" ## What changes were proposed in this pull request? A follow up PR for #14553 to fix the flaky test. It's flaky because the file list API doesn't guarantee any order of the return list. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15661 from zsxwing/fix-StreamingQuerySuite. 
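As context for the one-line fix below, here is a minimal sketch (illustrative only; the helper and directory are hypothetical, not part of the patch) of the pattern the test now follows: sort the listed metadata log file names before asserting on them, since a raw directory listing carries no ordering guarantee.

```scala
import java.io.File

// Hypothetical helper for illustration: File.listFiles() gives no ordering guarantee, so a
// test that asserts on the head of the raw listing can pass or fail depending on the
// filesystem. Sorting the names first makes the assertion deterministic.
def metadataLogFileNames(dir: File): Seq[String] = {
  Option(dir.listFiles()).getOrElse(Array.empty[File]).toSeq
    .map(_.getName)
    .filterNot(_.endsWith(".crc")) // workaround for SPARK-17475, mirroring the test
    .sorted                        // order the names before asserting, as in the fix below
}
```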
--- .../org/apache/spark/sql/streaming/StreamingQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dad410486ed2..464c443beb6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -265,7 +265,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery("metadata log should contain only two files") { q => val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) - val toTest = logFileNames.filter(! _.endsWith(".crc")) // Workaround for SPARK-17475 + val toTest = logFileNames.filter(! _.endsWith(".crc")).sorted // Workaround for SPARK-17475 assert(toTest.size == 2 && toTest.head == "1") true } From ccb11543048dccd4cc590a8db1df1d9d5847d112 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 27 Oct 2016 14:22:30 -0700 Subject: [PATCH 042/381] [SPARK-17970][SQL] store partition spec in metastore for data source table ## What changes were proposed in this pull request? We should follow hive table and also store partition spec in metastore for data source table. This brings 2 benefits: 1. It's more flexible to manage the table data files, as users can use `ADD PARTITION`, `DROP PARTITION` and `RENAME PARTITION` 2. We don't need to cache all file status for data source table anymore. ## How was this patch tested? existing tests. Author: Eric Liang Author: Michael Allman Author: Eric Liang Author: Wenchen Fan Closes #15515 from cloud-fan/partition. 
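With partition metadata stored in the metastore, the usual partition DDL becomes available for data source tables. A minimal sketch of the workflow (assumptions: a `SparkSession` named `spark` with Hive support, `spark.sql.hive.manageFilesourcePartitions` left at its default of `true`; the table name, columns and paths are hypothetical):

```scala
// Illustrative only -- table name, column names and locations are made up for this sketch.
spark.sql(
  "CREATE TABLE events (id BIGINT, day STRING) USING parquet PARTITIONED BY (day)")

// Import any partitions already present under the table location into the metastore.
spark.sql("MSCK REPAIR TABLE events")

// Partition management DDL now works on the data source table as it does for Hive tables.
spark.sql(
  "ALTER TABLE events ADD PARTITION (day='2016-10-27') LOCATION '/data/events/day=2016-10-27'")
spark.sql(
  "ALTER TABLE events PARTITION (day='2016-10-27') RENAME TO PARTITION (day='2016-10-28')")
spark.sql("ALTER TABLE events DROP PARTITION (day='2016-10-28')")
spark.sql("SHOW PARTITIONS events").show()
```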
--- .../sql/catalyst/catalog/interface.scala | 12 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 1 + .../apache/spark/sql/DataFrameWriter.scala | 13 +- .../command/AnalyzeColumnCommand.scala | 3 +- .../command/AnalyzeTableCommand.scala | 3 +- .../command/createDataSourceTables.scala | 17 +- .../spark/sql/execution/command/ddl.scala | 90 ++++---- .../spark/sql/execution/command/tables.scala | 39 ++-- .../execution/datasources/DataSource.scala | 20 +- .../datasources/DataSourceStrategy.scala | 15 +- .../execution/datasources/FileCatalog.scala | 4 + .../datasources/FileStatusCache.scala | 2 +- .../PartitioningAwareFileCatalog.scala | 12 +- .../datasources/TableFileCatalog.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../sql/execution/command/DDLSuite.scala | 200 +++++++----------- .../spark/sql/hive/HiveExternalCatalog.scala | 129 +++++++---- .../spark/sql/hive/HiveMetastoreCatalog.scala | 9 +- .../sql/hive/client/HiveClientImpl.scala | 5 +- .../sql/hive/HiveMetadataCacheSuite.scala | 2 +- .../PartitionProviderCompatibilitySuite.scala | 137 ++++++++++++ ...a => PartitionedTablePerfStatsSuite.scala} | 112 +++++++--- .../spark/sql/hive/StatisticsSuite.scala | 65 +++--- .../sql/hive/execution/HiveCommandSuite.scala | 5 +- .../sql/hive/execution/SQLQuerySuite.scala | 8 +- 26 files changed, 596 insertions(+), 329 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{HiveTablePerfStatsSuite.scala => PartitionedTablePerfStatsSuite.scala} (68%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index a97ed701c420..7c3bec897956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -89,9 +89,10 @@ case class CatalogTablePartition( parameters: Map[String, String] = Map.empty) { override def toString: String = { + val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") val output = Seq( - s"Partition Values: [${spec.values.mkString(", ")}]", + s"Partition Values: [$specString]", s"$storage", s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") @@ -137,6 +138,8 @@ case class BucketSpec( * Can be None if this table is a View, should be "hive" for hive serde tables. * @param unsupportedFeatures is a list of string descriptions of features that are used by the * underlying table but not supported by Spark SQL yet. + * @param partitionProviderIsHive whether this table's partition metadata is stored in the Hive + * metastore. 
*/ case class CatalogTable( identifier: TableIdentifier, @@ -154,7 +157,8 @@ case class CatalogTable( viewOriginalText: Option[String] = None, viewText: Option[String] = None, comment: Option[String] = None, - unsupportedFeatures: Seq[String] = Seq.empty) { + unsupportedFeatures: Seq[String] = Seq.empty, + partitionProviderIsHive: Boolean = false) { /** schema of this table's partition columns */ def partitionSchema: StructType = StructType(schema.filter { @@ -212,11 +216,11 @@ case class CatalogTable( comment.map("Comment: " + _).getOrElse(""), if (properties.nonEmpty) s"Properties: $tableProperties" else "", if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", - s"$storage") + s"$storage", + if (partitionProviderIsHive) "Partition Provider: Hive" else "") output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index cb0426c7a98a..3eff12f9eed1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -489,6 +489,7 @@ class TreeNodeSuite extends SparkFunSuite { "owner" -> "", "createTime" -> 0, "lastAccessTime" -> -1, + "partitionProviderIsHive" -> false, "properties" -> JNull, "unsupportedFeatures" -> List.empty[String])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4b5f0246b9a1..7ff3522f547d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Union} +import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -387,7 +388,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec ) - val cmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) + val createCmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) + val cmd = if (tableDesc.partitionColumnNames.nonEmpty && + df.sparkSession.sqlContext.conf.manageFilesourcePartitions) { + // Need to recover partitions into the metastore so our saved data is visible. 
+ val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(tableDesc.identifier) + Union(createCmd, recoverPartitionCmd) + } else { + createCmd + } df.sparkSession.sessionState.executePlan(cmd).toRdd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 488138709a12..f873f34a845e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -50,7 +50,8 @@ case class AnalyzeColumnCommand( AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + updateStats(logicalRel.catalogTable.get, + AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => throw new AnalysisException("ANALYZE TABLE is not supported for " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 7b0e49b665f4..52a8fc88c56c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -51,7 +51,8 @@ case class AnalyzeTableCommand( // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + updateTableStats(logicalRel.catalogTable.get, + AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => throw new AnalysisException("ANALYZE TABLE is not supported for " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index a8c75a7f29ce..2a9743130d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -94,10 +94,16 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo val newTable = table.copy( storage = table.storage.copy(properties = optionsWithPath), schema = dataSource.schema, - partitionColumnNames = partitionColumnNames) + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + partitionProviderIsHive = partitionColumnNames.nonEmpty && + sparkSession.sessionState.conf.manageFilesourcePartitions) // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. 
sessionState.catalog.createTable(newTable, ignoreIfExists = false) + Seq.empty[Row] } } @@ -232,6 +238,15 @@ case class CreateDataSourceTableAsSelectCommand( sessionState.catalog.createTable(newTable, ignoreIfExists = false) } + result match { + case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && + sparkSession.sqlContext.conf.manageFilesourcePartitions => + // Need to recover partitions into the metastore so our saved data is visible. + sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(table.identifier)).toRdd + case _ => + } + // Refresh the cache of the table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 15656faa08e4..61e0550cef5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -28,10 +28,11 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -346,10 +347,7 @@ case class AlterTableAddPartitionCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API") - } + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE ADD PARTITION") val parts = partitionSpecsAndLocs.map { case (spec, location) => val normalizedSpec = PartitioningUtils.normalizePartitionSpec( spec, @@ -382,11 +380,8 @@ case class AlterTableRenamePartitionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API") - } DDLUtils.verifyAlterTableType(catalog, table, isView = false) + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "ALTER TABLE RENAME PARTITION") val normalizedOldPartition = PartitioningUtils.normalizePartitionSpec( oldPartition, @@ -432,10 +427,7 @@ case class AlterTableDropPartitionCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") - } + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, 
"ALTER TABLE DROP PARTITION") val normalizedSpecs = specs.map { spec => PartitioningUtils.normalizePartitionSpec( @@ -493,33 +485,39 @@ case class AlterTableRecoverPartitionsCommand( } } + private def getBasePath(table: CatalogTable): Option[String] = { + if (table.provider == Some("hive")) { + table.storage.locationUri + } else { + new CaseInsensitiveMap(table.storage.properties).get("path") + } + } + override def run(spark: SparkSession): Seq[Row] = { val catalog = spark.sessionState.catalog val table = catalog.getTableMetadata(tableName) val tableIdentWithDB = table.identifier.quotedString DDLUtils.verifyAlterTableType(catalog, table, isView = false) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - s"Operation not allowed: $cmd on datasource tables: $tableIdentWithDB") - } if (table.partitionColumnNames.isEmpty) { throw new AnalysisException( s"Operation not allowed: $cmd only works on partitioned tables: $tableIdentWithDB") } - if (table.storage.locationUri.isEmpty) { + + val tablePath = getBasePath(table) + if (tablePath.isEmpty) { throw new AnalysisException(s"Operation not allowed: $cmd only works on table with " + s"location provided: $tableIdentWithDB") } - val root = new Path(table.storage.locationUri.get) + val root = new Path(tablePath.get) logInfo(s"Recover all the partitions in $root") val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) - val partitionSpecsAndLocs = scanPartitions( - spark, fs, pathFilter, root, Map(), table.partitionColumnNames.map(_.toLowerCase), threshold) + val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(), + table.partitionColumnNames, threshold, spark.sessionState.conf.resolver) val total = partitionSpecsAndLocs.length logInfo(s"Found $total partitions in $root") @@ -531,6 +529,11 @@ case class AlterTableRecoverPartitionsCommand( logInfo(s"Finished to gather the fast stats for all $total partitions.") addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + // Updates the table to indicate that its partition metadata is stored in the Hive metastore. + // This is always the case for Hive format tables, but is not true for Datasource tables created + // before Spark 2.1 unless they are converted via `msck repair table`. 
+ spark.sessionState.catalog.alterTable(table.copy(partitionProviderIsHive = true)) + catalog.refreshTable(tableName) logInfo(s"Recovered all partitions ($total).") Seq.empty[Row] } @@ -544,7 +547,8 @@ case class AlterTableRecoverPartitionsCommand( path: Path, spec: TablePartitionSpec, partitionNames: Seq[String], - threshold: Int): GenSeq[(TablePartitionSpec, Path)] = { + threshold: Int, + resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } @@ -563,15 +567,15 @@ case class AlterTableRecoverPartitionsCommand( val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) - val columnName = PartitioningUtils.unescapePathName(ps(0)).toLowerCase + val columnName = PartitioningUtils.unescapePathName(ps(0)) // TODO: Validate the value val value = PartitioningUtils.unescapePathName(ps(1)) - // comparing with case-insensitive, but preserve the case - if (columnName == partitionNames.head) { - scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(columnName -> value), - partitionNames.drop(1), threshold) + if (resolver(columnName, partitionNames.head)) { + scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), + partitionNames.drop(1), threshold, resolver) } else { - logWarning(s"expect partition column ${partitionNames.head}, but got ${ps(0)}, ignore it") + logWarning( + s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") Seq() } } else { @@ -676,16 +680,11 @@ case class AlterTableSetLocationCommand( DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => + DDLUtils.verifyPartitionProviderIsHive( + sparkSession, table, "ALTER TABLE ... SET LOCATION") // Partition spec is specified, so we set the location only for this partition val part = catalog.getPartition(table.identifier, spec) - val newPart = - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "ALTER TABLE SET LOCATION for partition is not allowed for tables defined " + - "using the datasource API") - } else { - part.copy(storage = part.storage.copy(locationUri = Some(location))) - } + val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself @@ -709,6 +708,25 @@ object DDLUtils { table.provider.isDefined && table.provider.get != "hive" } + /** + * Throws a standard error for actions that require partitionProvider = hive. + */ + def verifyPartitionProviderIsHive( + spark: SparkSession, table: CatalogTable, action: String): Unit = { + val tableName = table.identifier.table + if (!spark.sqlContext.conf.manageFilesourcePartitions && isDatasourceTable(table)) { + throw new AnalysisException( + s"$action is not allowed on $tableName since filesource partition management is " + + "disabled (spark.sql.hive.manageFilesourcePartitions = false).") + } + if (!table.partitionProviderIsHive && isDatasourceTable(table)) { + throw new AnalysisException( + s"$action is not allowed on $tableName since its partition metadata is not stored in " + + "the Hive metastore. To import this information into the metastore, run " + + s"`msck repair table $tableName`") + } + } + /** * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, * issue an exception [[AnalysisException]]. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index aec25430b719..4acfffb62804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -358,19 +358,16 @@ case class TruncateTableCommand( throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE on views: $tableIdentwithDB") } - val isDatasourceTable = DDLUtils.isDatasourceTable(table) - if (isDatasourceTable && partitionSpec.isDefined) { - throw new AnalysisException( - s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + - s"for tables created using the data sources API: $tableIdentwithDB") - } if (table.partitionColumnNames.isEmpty && partitionSpec.isDefined) { throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + s"for tables that are not partitioned: $tableIdentwithDB") } + if (partitionSpec.isDefined) { + DDLUtils.verifyPartitionProviderIsHive(spark, table, "TRUNCATE TABLE ... PARTITION") + } val locations = - if (isDatasourceTable) { + if (DDLUtils.isDatasourceTable(table)) { Seq(table.storage.properties.get("path")) } else if (table.partitionColumnNames.isEmpty) { Seq(table.storage.locationUri) @@ -453,7 +450,7 @@ case class DescribeTableCommand( describeFormattedTableInfo(metadata, result) } } else { - describeDetailedPartitionInfo(catalog, metadata, result) + describeDetailedPartitionInfo(sparkSession, catalog, metadata, result) } } @@ -492,6 +489,10 @@ case class DescribeTableCommand( describeStorageInfo(table, buffer) if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) + + if (DDLUtils.isDatasourceTable(table) && table.partitionProviderIsHive) { + append(buffer, "Partition Provider:", "Hive", "") + } } private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { @@ -528,6 +529,7 @@ case class DescribeTableCommand( } private def describeDetailedPartitionInfo( + spark: SparkSession, catalog: SessionCatalog, metadata: CatalogTable, result: ArrayBuffer[Row]): Unit = { @@ -535,10 +537,7 @@ case class DescribeTableCommand( throw new AnalysisException( s"DESC PARTITION is not allowed on a view: ${table.identifier}") } - if (DDLUtils.isDatasourceTable(metadata)) { - throw new AnalysisException( - s"DESC PARTITION is not allowed on a datasource table: ${table.identifier}") - } + DDLUtils.verifyPartitionProviderIsHive(spark, metadata, "DESC PARTITION") val partition = catalog.getPartition(table, partitionSpec) if (isExtended) { describeExtendedDetailedPartitionInfo(table, metadata, partition, result) @@ -743,10 +742,7 @@ case class ShowPartitionsCommand( s"SHOW PARTITIONS is not allowed on a table that is not partitioned: $tableIdentWithDB") } - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a datasource table: $tableIdentWithDB") - } + DDLUtils.verifyPartitionProviderIsHive(sparkSession, table, "SHOW PARTITIONS") /** * Validate the partitioning spec by making sure all the referenced columns are @@ -894,18 +890,11 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.properties.nonEmpty) { - val filteredProps = metadata.properties.filterNot { - // Skips "EXTERNAL" 
property for external tables - case (key, _) => key == "EXTERNAL" && metadata.tableType == EXTERNAL - } - - val props = filteredProps.map { case (key, value) => + val props = metadata.properties.map { case (key, value) => s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" } - if (props.nonEmpty) { - builder ++= props.mkString("TBLPROPERTIES (\n ", ",\n ", "\n)\n") - } + builder ++= props.mkString("TBLPROPERTIES (\n ", ",\n ", "\n)\n") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 17da606580ee..5b8f05a39624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -30,7 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider @@ -65,6 +65,8 @@ import org.apache.spark.util.Utils * @param partitionColumns A list of column names that the relation is partitioned by. When this * list is empty, the relation is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. + * @param catalogTable Optional catalog table reference that can be used to push down operations + * over the datasource to the catalog service. */ case class DataSource( sparkSession: SparkSession, @@ -73,7 +75,8 @@ case class DataSource( userSpecifiedSchema: Option[StructType] = None, partitionColumns: Seq[String] = Seq.empty, bucketSpec: Option[BucketSpec] = None, - options: Map[String, String] = Map.empty) extends Logging { + options: Map[String, String] = Map.empty, + catalogTable: Option[CatalogTable] = None) extends Logging { case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) @@ -412,9 +415,16 @@ case class DataSource( }) } - val fileCatalog = + val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.partitionProviderIsHive) { + new TableFileCatalog( + sparkSession, + catalogTable.get, + catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) + } else { new ListingFileCatalog( sparkSession, globbedPaths, options, partitionSchema) + } val dataSchema = userSpecifiedSchema.map { schema => val equality = sparkSession.sessionState.conf.resolver @@ -423,7 +433,7 @@ case class DataSource( format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.allFiles()) + fileCatalog.asInstanceOf[ListingFileCatalog].allFiles()) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. 
" + @@ -432,7 +442,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSpec().partitionColumns, + partitionSchema = fileCatalog.partitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7d0abe86a44d..f0bcf94eadc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -30,11 +30,11 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -179,7 +179,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - InsertIntoHadoopFsRelationCommand( + val insertCmd = InsertIntoHadoopFsRelationCommand( outputPath, query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, @@ -188,6 +188,15 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { t.options, query, mode) + + if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && + l.catalogTable.get.partitionProviderIsHive) { + // TODO(ekl) we should be more efficient here and only recover the newly added partitions + val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier) + Union(insertCmd, recoverPartitionCmd) + } else { + insertCmd + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala index 2bc66ceeebdb..dba64624c34b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType /** * A collection of data files from a partitioned relation, along with the partition values in the @@ -63,4 +64,7 @@ trait FileCatalog { /** Sum of table file sizes, in bytes */ def sizeInBytes: Long + + /** Schema of the partitioning columns, or the empty schema if the table is not partitioned. 
*/ + def partitionSchema: StructType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index e0ec748a0b34..7c2e6fd04d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -64,7 +64,7 @@ object FileStatusCache { */ def newCache(session: SparkSession): FileStatusCache = { synchronized { - if (session.sqlContext.conf.filesourcePartitionPruning && + if (session.sqlContext.conf.manageFilesourcePartitions && session.sqlContext.conf.filesourcePartitionFileCacheSize > 0) { if (sharedCache == null) { sharedCache = new SharedInMemoryCache( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala index 9b1903c47119..cc4049e92590 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -38,19 +38,21 @@ import org.apache.spark.util.SerializableConfiguration * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions -*/ + * @param userPartitionSchema an optional partition schema that will be use to provide types for + * the discovered partitions + */ abstract class PartitioningAwareFileCatalog( sparkSession: SparkSession, parameters: Map[String, String], - partitionSchema: Option[StructType], + userPartitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileCatalog with Logging { import PartitioningAwareFileCatalog.BASE_PATH_PARAM /** Returns the specification of the partitions inferred from the data. 
*/ def partitionSpec(): PartitionSpec + override def partitionSchema: StructType = partitionSpec().partitionColumns + protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] @@ -122,7 +124,7 @@ abstract class PartitioningAwareFileCatalog( val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => files.exists(f => isDataPath(f.getPath)) }.keys.toSeq - partitionSchema match { + userPartitionSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala index 667379b222c4..b459df5734d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType /** @@ -45,6 +46,8 @@ class TableFileCatalog( private val baseLocation = table.storage.locationUri + override def partitionSchema: StructType = table.partitionSchema + override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { @@ -63,7 +66,6 @@ class TableFileCatalog( if (table.partitionColumnNames.nonEmpty) { val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) - val partitionSchema = table.partitionSchema val partitions = selectedPartitions.map { p => PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f47ec7f3963a..dc31f3bc323f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -272,18 +272,20 @@ object SQLConf { .booleanConf .createWithDefault(true) - val HIVE_FILESOURCE_PARTITION_PRUNING = - SQLConfigBuilder("spark.sql.hive.filesourcePartitionPruning") - .doc("When true, enable metastore partition pruning for filesource relations as well. " + - "This is currently implemented for converted Hive tables only.") + val HIVE_MANAGE_FILESOURCE_PARTITIONS = + SQLConfigBuilder("spark.sql.hive.manageFilesourcePartitions") + .doc("When true, enable metastore partition management for file source tables as well. " + + "This includes both datasource and converted Hive tables. When partition managment " + + "is enabled, datasource tables store partition in the Hive metastore, and use the " + + "metastore to prune partitions during query planning.") .booleanConf .createWithDefault(true) val HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE = SQLConfigBuilder("spark.sql.hive.filesourcePartitionFileCacheSize") - .doc("When nonzero, enable caching of partition file metadata in memory. All table share " + + .doc("When nonzero, enable caching of partition file metadata in memory. All tables share " + "a cache that can use up to specified num bytes for file metadata. 
This conf only " + - "applies if filesource partition pruning is also enabled.") + "has an effect when hive filesource partition management is enabled.") .longConf .createWithDefault(250 * 1024 * 1024) @@ -679,7 +681,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - def filesourcePartitionPruning: Boolean = getConf(HIVE_FILESOURCE_PARTITION_PRUNING) + def manageFilesourcePartitions: Boolean = getConf(HIVE_MANAGE_FILESOURCE_PARTITIONS) def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 6857dd37286d..2d73d9f1fc80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -197,7 +197,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { assertResult(expected.schema, s"Schema did not match for query #$i\n${expected.sql}") { output.schema } - assertResult(expected.output, s"Result dit not match for query #$i\n${expected.sql}") { + assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { output.output } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b989d01ec787..9fb0f5384d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -95,7 +95,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { .add("b", "int"), provider = Some("hive"), partitionColumnNames = Seq("a", "b"), - createTime = 0L) + createTime = 0L, + partitionProviderIsHive = true) } private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { @@ -923,68 +924,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename partition") { - val catalog = spark.sessionState.catalog - val tableIdent = TableIdentifier("tab1", Some("dbx")) - createPartitionedTable(tableIdent, isDatasourceTable = false) - - // basic rename partition - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") - sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) - - // rename without explicitly specifying database - catalog.setCurrentDatabase("dbx") - sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) - - // table to alter does not exist - intercept[NoSuchTableException] { - sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") - } - - // partition to rename does not exist - intercept[NoSuchPartitionException] { - sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") - } - - // partition spec in RENAME PARTITION should be case insensitive by default - 
sql("ALTER TABLE tab1 PARTITION (A='10', B='p') RENAME TO PARTITION (A='1', B='p')") - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + testRenamePartitions(isDatasourceTable = false) } test("alter table: rename partition (datasource table)") { - createPartitionedTable(TableIdentifier("tab1", Some("dbx")), isDatasourceTable = true) - val e = intercept[AnalysisException] { - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") - }.getMessage - assert(e.contains( - "ALTER TABLE RENAME PARTITION is not allowed for tables defined using the datasource API")) - // table to alter does not exist - intercept[NoSuchTableException] { - sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") - } - } - - private def createPartitionedTable( - tableIdent: TableIdentifier, - isDatasourceTable: Boolean): Unit = { - val catalog = spark.sessionState.catalog - val part1 = Map("a" -> "1", "b" -> "q") - val part2 = Map("a" -> "2", "b" -> "c") - val part3 = Map("a" -> "3", "b" -> "p") - createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - createTablePartition(catalog, part1, tableIdent) - createTablePartition(catalog, part2, tableIdent) - createTablePartition(catalog, part3, tableIdent) - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + testRenamePartitions(isDatasourceTable = true) } test("show tables") { @@ -1199,7 +1143,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { if (spec.isDefined) { assert(storageFormat.properties.isEmpty) - assert(storageFormat.locationUri.isEmpty) + assert(storageFormat.locationUri === Some(expected)) } else { assert(storageFormat.properties.get("path") === Some(expected)) assert(storageFormat.locationUri === Some(expected)) @@ -1212,18 +1156,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") verifyLocation("/path/to/your/lovely/heart") // set table partition location - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - } + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") verifyLocation("/path/to/part/ways", Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") verifyLocation("/swanky/steak/place") // set table partition location without explicitly specifying database - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - } + sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") verifyLocation("vienna", Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { @@ -1354,26 +1294,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + - "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") - } - if (!isDatasourceTable) { - 
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) - } + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, part4)) - } + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) // table to alter does not exist intercept[AnalysisException] { @@ -1386,22 +1318,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // partition to add already exists when using IF NOT EXISTS - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, part4)) - } + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) // partition spec in ADD PARTITION should be case insensitive by default - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD PARTITION (A='9', B='9')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2, part3, part4, part5)) - } + sql("ALTER TABLE tab1 ADD PARTITION (A='9', B='9')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4, part5)) } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { @@ -1424,21 +1348,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // basic drop partition - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) - } + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) // drop partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - } + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // table to alter does not 
exist intercept[AnalysisException] { @@ -1451,20 +1367,56 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // partition to drop does not exist when using IF EXISTS - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") - } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - } + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // partition spec in DROP PARTITION should be case insensitive by default - maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP PARTITION (A='1', B='5')") + sql("ALTER TABLE tab1 DROP PARTITION (A='1', B='5')") + assert(catalog.listPartitions(tableIdent).isEmpty) + } + + private def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1", "b" -> "q") + val part2 = Map("a" -> "2", "b" -> "c") + val part3 = Map("a" -> "3", "b" -> "p") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + + // basic rename partition + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") + sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='20', b='c')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + + // rename without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) + + // table to alter does not exist + intercept[NoSuchTableException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") } - if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).isEmpty) + + // partition to rename does not exist + intercept[NoSuchPartitionException] { + sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") } + + // partition spec in RENAME PARTITION should be case insensitive by default + sql("ALTER TABLE tab1 PARTITION (A='10', B='p') RENAME TO PARTITION (A='1', B='p')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } test("drop build-in function") { @@ -1683,12 +1635,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - // truncating partitioned data source tables is not supported withTable("rectangles", "rectangles2") { data.write.saveAsTable("rectangles") data.write.partitionBy("length").saveAsTable("rectangles2") + + // not supported since the table is not partitioned assertUnsupported("TRUNCATE TABLE rectangles PARTITION (width=1)") - assertUnsupported("TRUNCATE TABLE rectangles2 PARTITION (width=1)") + + // supported since 
partitions are stored in the metastore + sql("TRUNCATE TABLE rectangles2 PARTITION (width=1)") + assert(spark.table("rectangles2").collect().isEmpty) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 2003ff42d4f0..409c316c6802 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -105,13 +106,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * metastore. */ private def verifyTableProperties(table: CatalogTable): Unit = { - val invalidKeys = table.properties.keys.filter { key => - key.startsWith(DATASOURCE_PREFIX) || key.startsWith(STATISTICS_PREFIX) - } + val invalidKeys = table.properties.keys.filter(_.startsWith(SPARK_SQL_PREFIX)) if (invalidKeys.nonEmpty) { throw new AnalysisException(s"Cannot persistent ${table.qualifiedName} into hive metastore " + - s"as table property keys may not start with '$DATASOURCE_PREFIX' or '$STATISTICS_PREFIX':" + - s" ${invalidKeys.mkString("[", ", ", "]")}") + s"as table property keys may not start with '$SPARK_SQL_PREFIX': " + + invalidKeys.mkString("[", ", ", "]")) } // External users are not allowed to set/switch the table type. In Hive metastore, the table // type can be switched by changing the value of a case-sensitive table property `EXTERNAL`. @@ -190,11 +189,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } // Before saving data source table metadata into Hive metastore, we should: - // 1. Put table schema, partition column names and bucket specification in table properties. + // 1. Put table provider, schema, partition column names, bucket specification and partition + // provider in table properties. // 2. Check if this table is hive compatible // 2.1 If it's not hive compatible, set schema, partition columns and bucket spec to empty // and save table metadata to Hive. - // 2.1 If it's hive compatible, set serde information in table metadata and try to save + // 2.2 If it's hive compatible, set serde information in table metadata and try to save // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 if (DDLUtils.isDatasourceTable(tableDefinition)) { // data source table always have a provider, it's guaranteed by `DDLUtils.isDatasourceTable`. @@ -204,6 +204,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put(DATASOURCE_PROVIDER, provider) + if (tableDefinition.partitionProviderIsHive) { + tableProperties.put(TABLE_PARTITION_PROVIDER, "hive") + } // Serialized JSON schema string may be too long to be stored into a single metastore table // property. 
In this case, we split the JSON string and store each part as a separate table @@ -241,12 +244,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // converts the table metadata to Spark SQL specific format, i.e. set schema, partition column - // names and bucket specification to empty. + // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and + // bucket specification to empty. Note that partition columns are retained, so that we can + // call partition-related Hive API later. def newSparkSQLSpecificMetastoreTable(): CatalogTable = { tableDefinition.copy( - schema = new StructType, - partitionColumnNames = Nil, + schema = tableDefinition.partitionSchema, bucketSpec = None, properties = tableDefinition.properties ++ tableProperties) } @@ -419,12 +422,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, // to retain the spark specific format if it is. Also add old data source properties to table // properties, to retain the data source table format. - val oldDataSourceProps = oldDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) + val oldDataSourceProps = oldDef.properties.filter(_._1.startsWith(SPARK_SQL_PREFIX)) + val partitionProviderProp = if (tableDefinition.partitionProviderIsHive) { + TABLE_PARTITION_PROVIDER -> "hive" + } else { + TABLE_PARTITION_PROVIDER -> "builtin" + } val newDef = withStatsProps.copy( schema = oldDef.schema, partitionColumnNames = oldDef.partitionColumnNames, bucketSpec = oldDef.bucketSpec, - properties = oldDataSourceProps ++ withStatsProps.properties) + properties = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp) client.alterTable(newDef) } else { @@ -448,7 +456,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * properties, and filter out these special entries from table properties. 
*/ private def restoreTableMetadata(table: CatalogTable): CatalogTable = { - val catalogTable = if (table.tableType == VIEW || conf.get(DEBUG_MODE)) { + if (conf.get(DEBUG_MODE)) { + return table + } + + val tableWithSchema = if (table.tableType == VIEW) { table } else { getProviderFromTableProperties(table).map { provider => @@ -473,30 +485,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat provider = Some(provider), partitionColumnNames = getPartitionColumnsFromTableProperties(table), bucketSpec = getBucketSpecFromTableProperties(table), - properties = getOriginalTableProperties(table)) + partitionProviderIsHive = table.properties.get(TABLE_PARTITION_PROVIDER) == Some("hive")) } getOrElse { - table.copy(provider = Some("hive")) + table.copy(provider = Some("hive"), partitionProviderIsHive = true) } } + // construct Spark's statistics from information in Hive metastore - val statsProps = catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - if (statsProps.nonEmpty) { + val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + val tableWithStats = if (statsProps.nonEmpty) { val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } - val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { + val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => val numFields = ColumnStatStruct.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap - catalogTable.copy( - properties = removeStatsProperties(catalogTable), + tableWithSchema.copy( stats = Some(Statistics( - sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), - rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)), + rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), colStats = colStats))) } else { - catalogTable + tableWithSchema } + + tableWithStats.copy(properties = getOriginalTableProperties(table)) } override def tableExists(db: String, table: String): Boolean = withClient { @@ -581,13 +595,30 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Partitions // -------------------------------------------------------------------------- + // Hive metastore is not case preserving and the partition columns are always lower cased. We need + // to lower case the column names in partition specification before calling partition related Hive + // APIs, to match this behaviour. + private def lowerCasePartitionSpec(spec: TablePartitionSpec): TablePartitionSpec = { + spec.map { case (k, v) => k.toLowerCase -> v } + } + + // Hive metastore is not case preserving and the column names of the partition specification we + // get from the metastore are always lower cased. We should restore them w.r.t. the actual table + // partition columns. 
+ private def restorePartitionSpec( + spec: TablePartitionSpec, + partCols: Seq[String]): TablePartitionSpec = { + spec.map { case (k, v) => partCols.find(_.equalsIgnoreCase(k)).get -> v } + } + override def createPartitions( db: String, table: String, parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = withClient { requireTableExists(db, table) - client.createPartitions(db, table, parts, ignoreIfExists) + val lowerCasedParts = parts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) } override def dropPartitions( @@ -597,7 +628,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat ignoreIfNotExists: Boolean, purge: Boolean): Unit = withClient { requireTableExists(db, table) - client.dropPartitions(db, table, parts, ignoreIfNotExists, purge) + client.dropPartitions(db, table, parts.map(lowerCasePartitionSpec), ignoreIfNotExists, purge) } override def renamePartitions( @@ -605,21 +636,24 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = withClient { - client.renamePartitions(db, table, specs, newSpecs) + client.renamePartitions( + db, table, specs.map(lowerCasePartitionSpec), newSpecs.map(lowerCasePartitionSpec)) } override def alterPartitions( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { - client.alterPartitions(db, table, newParts) + val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + client.alterPartitions(db, table, lowerCasedParts) } override def getPartition( db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition = withClient { - client.getPartition(db, table, spec) + val part = client.getPartition(db, table, lowerCasePartitionSpec(spec)) + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) } /** @@ -629,7 +663,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, spec: TablePartitionSpec): Option[CatalogTablePartition] = withClient { - client.getPartitionOption(db, table, spec) + client.getPartitionOption(db, table, lowerCasePartitionSpec(spec)).map { part => + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + } } /** @@ -639,14 +675,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { - client.getPartitions(db, table, partialSpec) + client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => + part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + } } override def listPartitionsByFilter( db: String, table: String, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withClient { - val catalogTable = client.getTable(db, table) + val rawTable = client.getTable(db, table) + val catalogTable = restoreTableMetadata(rawTable) val partitionColumnNames = catalogTable.partitionColumnNames.toSet val nonPartitionPruningPredicates = predicates.filterNot { _.references.map(_.name).toSet.subsetOf(partitionColumnNames) @@ -660,19 +699,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val partitionSchema = catalogTable.partitionSchema if (predicates.nonEmpty) { 
- val clientPrunedPartitions = - client.getPartitionsByFilter(catalogTable, predicates) + val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part => + part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames)) + } val boundPredicate = InterpretedPredicate.create(predicates.reduce(And).transform { case att: AttributeReference => val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) }) - clientPrunedPartitions.filter { case p: CatalogTablePartition => - boundPredicate(p.toRow(partitionSchema)) - } + clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) } } else { - client.getPartitions(catalogTable) + client.getPartitions(catalogTable).map { part => + part.copy(spec = restorePartitionSpec(part.spec, catalogTable.partitionColumnNames)) + } } } @@ -722,7 +762,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } object HiveExternalCatalog { - val DATASOURCE_PREFIX = "spark.sql.sources." + val SPARK_SQL_PREFIX = "spark.sql." + + val DATASOURCE_PREFIX = SPARK_SQL_PREFIX + "sources." val DATASOURCE_PROVIDER = DATASOURCE_PREFIX + "provider" val DATASOURCE_SCHEMA = DATASOURCE_PREFIX + "schema" val DATASOURCE_SCHEMA_PREFIX = DATASOURCE_SCHEMA + "." @@ -736,21 +778,20 @@ object HiveExternalCatalog { val DATASOURCE_SCHEMA_BUCKETCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "bucketCol." val DATASOURCE_SCHEMA_SORTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "sortCol." - val STATISTICS_PREFIX = "spark.sql.statistics." + val STATISTICS_PREFIX = SPARK_SQL_PREFIX + "statistics." val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." - def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } - } + val TABLE_PARTITION_PROVIDER = SPARK_SQL_PREFIX + "partitionProvider" + def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = { metadata.properties.get(DATASOURCE_PROVIDER) } def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(DATASOURCE_PREFIX) } + metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) } } // A persisted data source table always store its schema in the catalog. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6c1585d5f561..d1de863ce362 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -76,11 +76,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, - options = table.storage.properties) + options = table.storage.properties, + catalogTable = Some(table)) - LogicalRelation( - dataSource.resolveRelation(), - catalogTable = Some(table)) + LogicalRelation(dataSource.resolveRelation(), catalogTable = Some(table)) } } @@ -194,7 +193,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. - val lazyPruningEnabled = sparkSession.sqlContext.conf.filesourcePartitionPruning + val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 8835b266b22a..84873bbbb81c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -777,7 +777,7 @@ private[hive] class HiveClientImpl( val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => table.partitionColumnNames.contains(c.getName) } - if (table.schema.isEmpty) { + if (schema.isEmpty) { // This is a hack to preserve existing behavior. Before Spark 2.0, we do not // set a default serde here (this was done in Hive), and so if the user provides // an empty schema Hive would automatically populate the schema with a single @@ -831,9 +831,6 @@ private[hive] class HiveClientImpl( new HivePartition(ht, tpart) } - // TODO (cloud-fan): the column names in partition specification are always lower cased because - // Hive metastore is not case preserving. We should normalize them to the actual column names of - // the table, once we store partition spec of data source tables. 
private def fromHivePartition(hp: HivePartition): CatalogTablePartition = { val apiPartition = hp.getTPartition CatalogTablePartition( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index d290fe9962db..6e887d95c0f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -63,7 +63,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi def testCaching(pruningEnabled: Boolean): Unit = { test(s"partitioned table is cached when partition pruning is $pruningEnabled") { - withSQLConf(SQLConf.HIVE_FILESOURCE_PARTITION_PRUNING.key -> pruningEnabled.toString) { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> pruningEnabled.toString) { withTable("test") { withTempDir { dir => spark.range(5).selectExpr("id", "id as f1", "id as f2").write diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala new file mode 100644 index 000000000000..5f16960fb149 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class PartitionProviderCompatibilitySuite + extends QueryTest with TestHiveSingleton with SQLTestUtils { + + private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol").write + .partitionBy("partCol") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create table $tableName (fieldOne long, partCol int) + |using parquet + |options (path "${dir.getAbsolutePath}") + |partitioned by (partCol)""".stripMargin) + } + + private def verifyIsLegacyTable(tableName: String): Unit = { + val unsupportedCommands = Seq( + s"ALTER TABLE $tableName ADD PARTITION (partCol=1) LOCATION '/foo'", + s"ALTER TABLE $tableName PARTITION (partCol=1) RENAME TO PARTITION (partCol=2)", + s"ALTER TABLE $tableName PARTITION (partCol=1) SET LOCATION '/foo'", + s"ALTER TABLE $tableName DROP PARTITION (partCol=1)", + s"DESCRIBE $tableName PARTITION (partCol=1)", + s"SHOW PARTITIONS $tableName") + + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + for (cmd <- unsupportedCommands) { + val e = intercept[AnalysisException] { + spark.sql(cmd) + } + assert(e.getMessage.contains("partition metadata is not stored in the Hive metastore"), e) + } + } + } + + test("convert partition provider to hive with repair table") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + setupPartitionedDatasourceTable("test", dir) + assert(spark.sql("select * from test").count() == 5) + } + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + verifyIsLegacyTable("test") + spark.sql("msck repair table test") + spark.sql("show partitions test").count() // check we are a new table + + // sanity check table performance + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 2) + } + } + } + } + + test("when partition management is enabled, new tables have partition provider hive") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + setupPartitionedDatasourceTable("test", dir) + spark.sql("show partitions test").count() // check we are a new table + assert(spark.sql("select * from test").count() == 0) // needs repair + spark.sql("msck repair table test") + assert(spark.sql("select * from test").count() == 5) + } + } + } + } + + test("when partition management is disabled, new tables have no partition provider") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + setupPartitionedDatasourceTable("test", dir) + verifyIsLegacyTable("test") + assert(spark.sql("select * from test").count() == 5) + } + } + } + } + + test("when partition management is disabled, we preserve the old behavior even for new tables") { + withTable("test") { + withTempDir { dir => + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + setupPartitionedDatasourceTable("test", dir) + spark.sql("show 
partitions test").count() // check we are a new table + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 0) + } + // disabled + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + val e = intercept[AnalysisException] { + spark.sql(s"show partitions test") + } + assert(e.getMessage.contains("filesource partition management is disabled")) + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 5) + } + // then enabled again + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + spark.sql("refresh table test") + assert(spark.sql("select * from test").count() == 0) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala similarity index 68% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTablePerfStatsSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 82ee813c6a95..476383a5b33a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -class HiveTablePerfStatsSuite +class PartitionedTablePerfStatsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { override def beforeEach(): Unit = { @@ -41,25 +41,54 @@ class HiveTablePerfStatsSuite FileStatusCache.resetForTesting() } - private def setupPartitionedTable(tableName: String, dir: File): Unit = { - spark.range(5).selectExpr("id", "id as partCol1", "id as partCol2").write + private case class TestSpec(setupTable: (String, File) => Unit, isDatasourceTable: Boolean) + + /** + * Runs a test against both converted hive and native datasource tables. The test can use the + * passed TestSpec object for setup and inspecting test parameters. 
+ */ + private def genericTest(testName: String)(fn: TestSpec => Unit): Unit = { + test("hive table: " + testName) { + fn(TestSpec(setupPartitionedHiveTable, false)) + } + test("datasource table: " + testName) { + fn(TestSpec(setupPartitionedDatasourceTable, true)) + } + } + + private def setupPartitionedHiveTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write .partitionBy("partCol1", "partCol2") .mode("overwrite") .parquet(dir.getAbsolutePath) spark.sql(s""" - |create external table $tableName (id long) + |create external table $tableName (fieldOne long) |partitioned by (partCol1 int, partCol2 int) |stored as parquet |location "${dir.getAbsolutePath}"""".stripMargin) spark.sql(s"msck repair table $tableName") } - test("partitioned pruned table reports only selected files") { + private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { + spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + .partitionBy("partCol1", "partCol2") + .mode("overwrite") + .parquet(dir.getAbsolutePath) + + spark.sql(s""" + |create table $tableName (fieldOne long, partCol1 int, partCol2 int) + |using parquet + |options (path "${dir.getAbsolutePath}") + |partitioned by (partCol1, partCol2)""".stripMargin) + spark.sql(s"msck repair table $tableName") + } + + genericTest("partitioned pruned table reports only selected files") { spec => assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + spec.setupTable("test", dir) val df = spark.sql("select * from test") assert(df.count() == 5) assert(df.inputFiles.length == 5) // unpruned @@ -75,17 +104,24 @@ class HiveTablePerfStatsSuite val df4 = spark.sql("select * from test where partCol1 = 999") assert(df4.count() == 0) assert(df4.inputFiles.length == 0) + + // TODO(ekl) enable for hive tables as well once SPARK-17983 is fixed + if (spec.isDatasourceTable) { + val df5 = spark.sql("select * from test where fieldOne = 4") + assert(df5.count() == 1) + assert(df5.inputFiles.length == 5) + } } } } - test("lazy partition pruning reads only necessary partition data") { + genericTest("lazy partition pruning reads only necessary partition data") { spec => withSQLConf( - SQLConf.HIVE_FILESOURCE_PARTITION_PRUNING.key -> "true", + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "0") { withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + spec.setupTable("test", dir) HiveCatalogMetrics.reset() spark.sql("select * from test where partCol1 = 999").count() assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) @@ -120,13 +156,13 @@ class HiveTablePerfStatsSuite } } - test("lazy partition pruning with file status caching enabled") { + genericTest("lazy partition pruning with file status caching enabled") { spec => withSQLConf( - "spark.sql.hive.filesourcePartitionPruning" -> "true", - "spark.sql.hive.filesourcePartitionFileCacheSize" -> "9999999") { + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "9999999") { withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + spec.setupTable("test", dir) HiveCatalogMetrics.reset() assert(spark.sql("select * from test where partCol1 = 999").count() == 0) 
assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) @@ -161,13 +197,13 @@ class HiveTablePerfStatsSuite } } - test("file status caching respects refresh table and refreshByPath") { + genericTest("file status caching respects refresh table and refreshByPath") { spec => withSQLConf( - "spark.sql.hive.filesourcePartitionPruning" -> "true", - "spark.sql.hive.filesourcePartitionFileCacheSize" -> "9999999") { + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "9999999") { withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + spec.setupTable("test", dir) HiveCatalogMetrics.reset() assert(spark.sql("select * from test").count() == 5) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) @@ -190,13 +226,13 @@ class HiveTablePerfStatsSuite } } - test("file status cache respects size limit") { + genericTest("file status cache respects size limit") { spec => withSQLConf( - "spark.sql.hive.filesourcePartitionPruning" -> "true", - "spark.sql.hive.filesourcePartitionFileCacheSize" -> "1" /* 1 byte */) { + SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true", + SQLConf.HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE.key -> "1" /* 1 byte */) { withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + spec.setupTable("test", dir) HiveCatalogMetrics.reset() assert(spark.sql("select * from test").count() == 5) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) @@ -209,11 +245,11 @@ class HiveTablePerfStatsSuite } } - test("all partitions read and cached when filesource partition pruning is off") { - withSQLConf(SQLConf.HIVE_FILESOURCE_PARTITION_PRUNING.key -> "false") { + test("hive table: files read and cached when filesource partition management is off") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { withTempDir { dir => - setupPartitionedTable("test", dir) + setupPartitionedHiveTable("test", dir) // We actually query the partitions from hive each time the table is resolved in this // mode. 
This is kind of terrible, but is needed to preserve the legacy behavior @@ -237,4 +273,32 @@ class HiveTablePerfStatsSuite } } } + + test("datasource table: all partition data cached in memory when partition management is off") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 999").count() == 0) + + // not using metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + + // reads and caches all the files initially + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 < 2").count() == 2) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 5) + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 0) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index c351063a63ff..4f5ebc3d838b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -310,39 +310,50 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - test("test table-level statistics for data source table created in HiveExternalCatalog") { - val parquetTable = "parquetTable" - withTable(parquetTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) USING PARQUET") - val catalogTable = spark.sessionState.catalog.getTableMetadata(TableIdentifier(parquetTable)) - assert(DDLUtils.isDatasourceTable(catalogTable)) + private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = { + test("test table-level statistics for " + tableDescription) { + val parquetTable = "parquetTable" + withTable(parquetTable) { + sql(createTableCmd) + val catalogTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(parquetTable)) + assert(DDLUtils.isDatasourceTable(catalogTable)) + + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) + // noscan won't count the number of rows + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - // noscan won't count the number of rows - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") + val fetchedStats2 = checkTableStats( + parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) + 
assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkTableStats( - parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) - - // without noscan, we count the number of rows - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkTableStats( - parquetTable, - isDataSourceTable = true, - hasSizeInBytes = true, - expectedRowCounts = Some(1000)) - assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) + // without noscan, we count the number of rows + sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") + val fetchedStats3 = checkTableStats( + parquetTable, + isDataSourceTable = true, + hasSizeInBytes = true, + expectedRowCounts = Some(1000)) + assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) + } } } + testUpdatingTableStats( + "data source table created in HiveExternalCatalog", + "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET") + + testUpdatingTableStats( + "partitioned data source table", + "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET PARTITIONED BY (key)") + test("statistics collection of a table with zero column") { val table_no_cols = "table_no_cols" withTable(table_no_cols) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index ad1e9b17a9f7..46ed18c70fb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -415,10 +415,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto .mode(SaveMode.Overwrite) .saveAsTable("part_datasrc") - val message1 = intercept[AnalysisException] { - sql("SHOW PARTITIONS part_datasrc") - }.getMessage - assert(message1.contains("is not allowed on a datasource table")) + assert(sql("SHOW PARTITIONS part_datasrc").count() == 3) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 495b4f874a1d..01fa827220c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -358,7 +358,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "# Partition Information", "# col_name", "Detailed Partition Information CatalogPartition(", - "Partition Values: [Us, 1]", + "Partition Values: [c=Us, d=1]", "Storage(Location:", "Partition Parameters") @@ -399,10 +399,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write .partitionBy("d") .saveAsTable("datasource_table") - val m4 = intercept[AnalysisException] { - sql("DESC datasource_table PARTITION (d=2)") - }.getMessage() - assert(m4.contains("DESC PARTITION is not allowed on a datasource table")) + + sql("DESC datasource_table PARTITION (d=0)") val m5 = intercept[AnalysisException] { spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") From 
ab5f938bc7c3c9b137d63e479fced2b7e9c9d75b Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Fri, 28 Oct 2016 08:39:02 +0800 Subject: [PATCH 043/381] [SPARK-18121][SQL] Unable to query global temp views when hive support is enabled ## What changes were proposed in this pull request? Issue: Querying on a global temp view throws Table or view not found exception. Fix: Update the lookupRelation in HiveSessionCatalog to check for global temp views similar to the SessionCatalog.lookupRelation. Before fix: Querying on a global temp view ( for. e.g.: select * from global_temp.v1) throws Table or view not found exception After fix: Query succeeds and returns the right result. ## How was this patch tested? - Two unit tests are added to check for global temp view for the code path when hive support is enabled. - Regression unit tests were run successfully. ( build/sbt -Phive hive/test, build/sbt sql/test, build/sbt catalyst/test) Author: Sunitha Kambhampati Closes #15649 from skambha/lookuprelationChanges. --- .../spark/sql/hive/HiveSessionCatalog.scala | 10 ++++++++-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 85ecf0ce7075..4f2910abfd21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchTableException} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} @@ -57,7 +57,13 @@ private[sql] class HiveSessionCatalog( override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = { val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.contains(table)) { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + if (db == globalTempViewManager.database) { + val relationAlias = alias.getOrElse(table) + globalTempViewManager.get(table).map { viewDef => + SubqueryAlias(relationAlias, viewDef, Some(name)) + }.getOrElse(throw new NoSuchTableException(db, table)) + } else if (name.database.isDefined || !tempTables.contains(table)) { val database = name.database.map(formatDatabaseName) val newName = name.copy(database = database, table = table) metastoreCatalog.lookupRelation(newName, alias) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 01fa827220c5..2735d3a5267e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,6 +68,22 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext._ import 
spark.implicits._ + test("query global temp view") { + val df = Seq(1).toDF("i1") + df.createGlobalTempView("tbl1") + val global_temp_db = spark.conf.get("spark.sql.globalTempDatabase") + checkAnswer(spark.sql(s"select * from ${global_temp_db}.tbl1"), Row(1)) + spark.sql(s"drop view ${global_temp_db}.tbl1") + } + + test("non-existent global temp view") { + val global_temp_db = spark.conf.get("spark.sql.globalTempDatabase") + val message = intercept[AnalysisException] { + spark.sql(s"select * from ${global_temp_db}.nonexistentview") + }.getMessage + assert(message.contains("Table or view not found")) + } + test("script") { val scriptFilePath = getTestResourcePath("test_script.sh") if (testCommandAvailable("bash") && testCommandAvailable("echo | sed")) { From 569788a55e4c6b218fb697e1e54c6138ffe657a6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 28 Oct 2016 00:40:06 -0700 Subject: [PATCH 044/381] [SPARK-18109][ML] Add instrumentation to GMM ## What changes were proposed in this pull request? Add instrumentation to GMM ## How was this patch tested? Test in spark-shell Author: Zheng RuiFeng Closes #15636 from zhengruifeng/gmm_instr. --- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index e3cb92f4f144..8fac63fefbb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -323,6 +323,9 @@ class GaussianMixture @Since("2.0.0") ( case Row(point: Vector) => OldVectors.fromML(point) } + val instr = Instrumentation.create(this, rdd) + instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) + val algo = new MLlibGM() .setK($(k)) .setMaxIterations($(maxIter)) @@ -337,6 +340,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) model.setSummary(summary) + instr.logNumFeatures(model.gaussians.head.mean.size) + instr.logSuccess(model) + model } @Since("2.0.0") From e9746f87d0b553b8115948acb79f7e32c23dfd86 Mon Sep 17 00:00:00 2001 From: Jagadeesan Date: Fri, 28 Oct 2016 02:26:55 -0700 Subject: [PATCH 045/381] =?UTF-8?q?[SPARK-18133][EXAMPLES][ML]=20Python=20?= =?UTF-8?q?ML=20Pipeline=20Example=20has=20syntax=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In Python 3, there is only one integer type (i.e., int), which mostly behaves like the long type in Python 2. Since Python 3 won't accept "L", so removed "L" in all examples. ## How was this patch tested? Unit tests. …rrors] Author: Jagadeesan Closes #15660 from jagadeesanas2/SPARK-18133. 
--- examples/src/main/python/ml/cross_validator.py | 8 ++++---- .../main/python/ml/gaussian_mixture_example.py | 2 +- examples/src/main/python/ml/pipeline_example.py | 16 ++++++++-------- .../binary_classification_metrics_example.py | 2 +- .../python/mllib/multi_class_metrics_example.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index 907eec67a0eb..db7054307c2e 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -84,10 +84,10 @@ # Prepare test documents, which are unlabeled. test = spark.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") + (4, "spark i j k"), + (5, "l m n"), + (6, "mapreduce spark"), + (7, "apache hadoop") ], ["id", "text"]) # Make predictions on test documents. cvModel uses the best model found (lrModel). diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index 8ad450b669fc..e4a0d314e9d9 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -38,7 +38,7 @@ # loads data dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") - gmm = GaussianMixture().setK(2).setSeed(538009335L) + gmm = GaussianMixture().setK(2).setSeed(538009335) model = gmm.fit(dataset) print("Gaussians shown as a DataFrame: ") diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py index f63e4db43422..e1fab7cbe6d8 100644 --- a/examples/src/main/python/ml/pipeline_example.py +++ b/examples/src/main/python/ml/pipeline_example.py @@ -35,10 +35,10 @@ # $example on$ # Prepare training documents from a list of (id, text, label) tuples. training = spark.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0) + (0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0) ], ["id", "text", "label"]) # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. @@ -52,10 +52,10 @@ # Prepare test documents, which are unlabeled (id, text) tuples. test = spark.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "spark hadoop spark"), - (7L, "apache hadoop") + (4, "spark i j k"), + (5, "l m n"), + (6, "spark hadoop spark"), + (7, "apache hadoop") ], ["id", "text"]) # Make predictions on test documents and print columns of interest. 
diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py index daf000e38dcd..91f8378f29c0 100644 --- a/examples/src/main/python/mllib/binary_classification_metrics_example.py +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -39,7 +39,7 @@ .rdd.map(lambda row: LabeledPoint(row[0], row[1])) # Split data into training (60%) and test (40%) - training, test = data.randomSplit([0.6, 0.4], seed=11L) + training, test = data.randomSplit([0.6, 0.4], seed=11) training.cache() # Run training algorithm to build the model diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py index cd56b3c97c77..7dc5fb4f9127 100644 --- a/examples/src/main/python/mllib/multi_class_metrics_example.py +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -32,7 +32,7 @@ data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") # Split data into training (60%) and test (40%) - training, test = data.randomSplit([0.6, 0.4], seed=11L) + training, test = data.randomSplit([0.6, 0.4], seed=11) training.cache() # Run training algorithm to build the model From ac26e9cf27862fbfb97ae18d591606ecf2cd41cf Mon Sep 17 00:00:00 2001 From: Yunni Date: Fri, 28 Oct 2016 14:57:52 -0700 Subject: [PATCH 046/381] [SPARK-5992][ML] Locality Sensitive Hashing ## What changes were proposed in this pull request? Implement Locality Sensitive Hashing along with approximate nearest neighbors and approximate similarity join based on the [design doc](https://docs.google.com/document/d/1D15DTDMF_UWTTyWqXfG7y76iZalky4QmifUYQ6lH5GM/edit). Detailed changes are as follows: (1) Implement abstract LSH, LSHModel classes as Estimator-Model (2) Implement approxNearestNeighbors and approxSimilarityJoin in the abstract LSHModel (3) Implement Random Projection as LSH subclass for Euclidean distance, Min Hash for Jaccard Distance (4) Implement unit test utility methods including checkLshProperty, checkNearestNeighbor and checkSimilarityJoin Things that will be implemented in a follow-up PR: - Bit Sampling for Hamming Distance, SignRandomProjection for Cosine Distance - PySpark Integration for the scala classes and methods. ## How was this patch tested? Unit test is implemented for all the implemented classes and algorithms. A scalability test on Uber's dataset was performed internally. Tested the methods on [WEX dataset](https://aws.amazon.com/items/2345) from AWS, with the steps and results [here](https://docs.google.com/document/d/19BXg-67U83NVB3M0I84HVBVg3baAVaESD_mrg_-vLro/edit). ## References Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions via hashing." VLDB 7 Sep. 1999: 518-529. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint arXiv:1408.2927 (2014). Author: Yunni Author: Yun Ni Closes #15148 from Yunni/SPARK-5992-yunn-lsh. 
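**Usage sketch (illustrative only, not part of this patch).** The snippet below shows roughly how the new API could be exercised with the Euclidean-distance `RandomProjection` estimator. The `approxNearestNeighbors` and `approxSimilarityJoin` calls follow the `LSHModel` signatures added in `LSH.scala` below; the setter names `setOutputDim` and `setBucketLength` are assumptions about `RandomProjection`'s parameters, which are not shown in this excerpt.

```scala
import org.apache.spark.ml.feature.RandomProjection
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("LSHUsageSketch").getOrCreate()

// Small toy dataset of 2-dimensional feature vectors.
val dfA = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 1.0)),
  (1, Vectors.dense(1.0, -1.0)),
  (2, Vectors.dense(-1.0, -1.0)),
  (3, Vectors.dense(-1.0, 1.0))
)).toDF("id", "keys")

// RandomProjection is the LSH family for Euclidean distance.
val rp = new RandomProjection()
  .setInputCol("keys")       // from HasInputCol
  .setOutputCol("hashes")    // from HasOutputCol
  .setOutputDim(2)           // LSHParams.outputDim (OR-amplification); setter name assumed
  .setBucketLength(2.0)      // assumed RandomProjection-specific parameter

val model = rp.fit(dfA)

// Approximate nearest neighbors: at most 2 rows of dfA closest to the query key,
// using the 3-argument overload (single probing, default "distCol" output column).
val key = Vectors.dense(1.0, 0.0)
model.approxNearestNeighbors(dfA, key, 2).show()

// Approximate similarity self-join: pairs of rows whose Euclidean distance is below 1.5.
model.approxSimilarityJoin(dfA, dfA, 1.5, "distCol").show()
```

Note that self-joins are supported (the implementation below recreates the input column of `datasetB` to avoid attribute ambiguity), and `approxNearestNeighbors` without extra arguments defaults to single probing with a `distCol` output column.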
--- .../org/apache/spark/ml/feature/LSH.scala | 313 ++++++++++++++++++ .../org/apache/spark/ml/feature/MinHash.scala | 194 +++++++++++ .../spark/ml/feature/RandomProjection.scala | 225 +++++++++++++ .../org/apache/spark/ml/feature/LSHTest.scala | 153 +++++++++ .../spark/ml/feature/MinHashSuite.scala | 126 +++++++ .../ml/feature/RandomProjectionSuite.scala | 197 +++++++++++ 6 files changed, 1208 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala new file mode 100644 index 000000000000..333a8c364a88 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Params for [[LSH]]. + */ +private[ml] trait LSHParams extends HasInputCol with HasOutputCol { + /** + * Param for the dimension of LSH OR-amplification. + * + * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The + * higher the dimension is, the lower the false negative rate. + * @group param + */ + final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + + "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " improves the running performance", ParamValidators.gt(0)) + + /** @group getParam */ + final def getOutputDim: Int = $(outputDim) + + setDefault(outputDim -> 1) + + /** + * Transform the Schema for LSH + * @param schema The schema of the input dataset without [[outputCol]] + * @return A derived schema with [[outputCol]] added + */ + protected[this] final def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } +} + +/** + * Model produced by [[LSH]]. 
+ */ +private[ml] abstract class LSHModel[T <: LSHModel[T]] + extends Model[T] with LSHParams with MLWritable { + self: T => + + /** + * The hash function of LSH, mapping a predefined KeyType to a Vector + * @return The mapping of LSH function. + */ + protected[ml] val hashFunction: Vector => Vector + + /** + * Calculate the distance between two different keys using the distance metric corresponding + * to the hashFunction + * @param x One input vector in the metric space + * @param y One input vector in the metric space + * @return The distance between x and y + */ + protected[ml] def keyDistance(x: Vector, y: Vector): Double + + /** + * Calculate the distance between two different hash Vectors. + * + * @param x One of the hash vector + * @param y Another hash vector + * @return The distance between hash vectors x and y + */ + protected[ml] def hashDistance(x: Vector, y: Vector): Double + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + val transformUDF = udf(hashFunction, new VectorUDT) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Given a large dataset and an item, approximately find at most k items which have the closest + * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if + * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the + * transformed data when necessary. + * + * This method implements two ways of fetching k nearest neighbors: + * - Single Probing: Fast, return at most k elements (Probing only one buckets) + * - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key) + * + * @param dataset the dataset to search for nearest neighbors of the key + * @param key Feature vector representing the item to search for + * @param numNearestNeighbors The maximum number of nearest neighbors + * @param singleProbing True for using Single Probing; false for multiple probing + * @param distCol Output column for storing the distance between each result row and the key + * @return A dataset containing at most k items closest to the key. A distCol is added to show + * the distance between each row and the key. + */ + def approxNearestNeighbors( + dataset: Dataset[_], + key: Vector, + numNearestNeighbors: Int, + singleProbing: Boolean, + distCol: String): Dataset[_] = { + require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1") + // Get Hash Value of the key + val keyHash = hashFunction(key) + val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) { + transform(dataset) + } else { + dataset.toDF() + } + + // In the origin dataset, find the hash value that is closest to the key + val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType) + val hashDistCol = hashDistUDF(col($(outputCol))) + + val modelSubset = if (singleProbing) { + modelDataset.filter(hashDistCol === 0.0) + } else { + // Compute threshold to get exact k elements. + val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors) + val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol)) + val hashThreshold = thresholdDataset.take(1).head.getDouble(0) + + // Filter the dataset where the hash value is less than the threshold. 
+ modelDataset.filter(hashDistCol <= hashThreshold) + } + + // Get the top k nearest neighbor by their distance to the key + val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType) + val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol)))) + modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors) + } + + /** + * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search + * nearest neighbors and "distCol" as default distCol. + */ + def approxNearestNeighbors( + dataset: Dataset[_], + key: Vector, + numNearestNeighbors: Int): Dataset[_] = { + approxNearestNeighbors(dataset, key, numNearestNeighbors, true, "distCol") + } + + /** + * Preprocess step for approximate similarity join. Transform and explode the [[outputCol]] to + * two explodeCols: entry and value. "entry" is the index in hash vector, and "value" is the + * value of corresponding value of the index in the vector. + * + * @param dataset The dataset to transform and explode. + * @param explodeCols The alias for the exploded columns, must be a seq of two strings. + * @return A dataset containing idCol, inputCol and explodeCols + */ + private[this] def processDataset( + dataset: Dataset[_], + inputName: String, + explodeCols: Seq[String]): Dataset[_] = { + require(explodeCols.size == 2, "explodeCols must be two strings.") + val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap, + MapType(DataTypes.IntegerType, DataTypes.DoubleType)) + val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) { + transform(dataset) + } else { + dataset.toDF() + } + modelDataset.select( + struct(col("*")).as(inputName), + explode(vectorToMap(col($(outputCol)))).as(explodeCols)) + } + + /** + * Recreate a column using the same column name but different attribute id. Used in approximate + * similarity join. + * @param dataset The dataset where a column need to recreate + * @param colName The name of the column to recreate + * @param tmpColName A temporary column name which does not conflict with existing columns + * @return + */ + private[this] def recreateCol( + dataset: Dataset[_], + colName: String, + tmpColName: String): Dataset[_] = { + dataset + .withColumnRenamed(colName, tmpColName) + .withColumn(colName, col(tmpColName)) + .drop(tmpColName) + } + + /** + * Join two dataset to approximately find all pairs of rows whose distance are smaller than + * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the + * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed + * data when necessary. + * + * @param datasetA One of the datasets to join + * @param datasetB Another dataset to join + * @param threshold The threshold for the distance of row pairs + * @param distCol Output column for storing the distance between each result row and the key + * @return A joined dataset containing pairs of rows. The original rows are in columns + * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair + */ + def approxSimilarityJoin( + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double, + distCol: String): Dataset[_] = { + + val leftColName = "datasetA" + val rightColName = "datasetB" + val explodeCols = Seq("entry", "hashValue") + val explodedA = processDataset(datasetA, leftColName, explodeCols) + + // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity. 
+ // TODO: Remove recreateCol logic once SPARK-17154 is resolved. + val explodedB = if (datasetA != datasetB) { + processDataset(datasetB, rightColName, explodeCols) + } else { + val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}") + processDataset(recreatedB, rightColName, explodeCols) + } + + // Do a hash join on where the exploded hash values are equal. + val joinedDataset = explodedA.join(explodedB, explodeCols) + .drop(explodeCols: _*).distinct() + + // Add a new column to store the distance of the two rows. + val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType) + val joinedDatasetWithDist = joinedDataset.select(col("*"), + distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol) + ) + + // Filter the joined datasets where the distance are smaller than the threshold. + joinedDatasetWithDist.filter(col(distCol) < threshold) + } + + /** + * Overloaded method for approxSimilarityJoin. Use "distCol" as default distCol. + */ + def approxSimilarityJoin( + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double): Dataset[_] = { + approxSimilarityJoin(datasetA, datasetB, threshold, "distCol") + } +} + +/** + * Locality Sensitive Hashing for different metrics space. Support basic transformation with a new + * hash column, approximate nearest neighbor search with a dataset and a key, and approximate + * similarity join of two datasets. + * + * This LSH class implements OR-amplification: more than 1 hash functions can be chosen, and each + * input vector are hashed by all hash functions. Two input vectors are defined to be in the same + * bucket as long as ANY one of the hash value matches. + * + * References: + * (1) Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions + * via hashing." VLDB 7 Sep. 1999: 518-529. + * (2) Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint + * arXiv:1408.2927 (2014). + */ +private[ml] abstract class LSH[T <: LSHModel[T]] + extends Estimator[T] with LSHParams with DefaultParamsWritable { + self: Estimator[T] => + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setOutputDim(value: Int): this.type = set(outputDim, value) + + /** + * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have + * different initial setting, developer needs to define how their LSHModel is created instead of + * using reflection in this abstract class. + * @param inputDim The dimension of the input dataset + * @return A new LSHModel instance without any params + */ + protected[this] def createRawLSHModel(inputDim: Int): T + + override def fit(dataset: Dataset[_]): T = { + transformSchema(dataset.schema, logging = true) + val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size + val model = createRawLSHModel(inputDim).setParent(this) + copyValues(model) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala new file mode 100644 index 000000000000..d9d0f32254e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * + * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is + * a perfect hash function: + * `h_i(x) = (x * k_i mod prime) mod numEntries` + * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*` + * + * Reference: + * [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]] + * + * @param numEntries The number of entries of the hash functions. + * @param randCoefficients An array of random coefficients, each used by one hash function. + */ +@Experimental +@Since("2.1.0") +class MinHashModel private[ml] ( + override val uid: String, + @Since("2.1.0") val numEntries: Int, + @Since("2.1.0") val randCoefficients: Array[Int]) + extends LSHModel[MinHashModel] { + + @Since("2.1.0") + override protected[ml] val hashFunction: Vector => Vector = { + elems: Vector => + require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") + val elemsList = elems.toSparse.indices.toList + val hashValues = randCoefficients.map({ randCoefficient: Int => + elemsList.map({elem: Int => + (1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries + }).min.toDouble + }) + Vectors.dense(hashValues) + } + + @Since("2.1.0") + override protected[ml] def keyDistance(x: Vector, y: Vector): Double = { + val xSet = x.toSparse.indices.toSet + val ySet = y.toSparse.indices.toSet + val intersectionSize = xSet.intersect(ySet).size.toDouble + val unionSize = xSet.size + ySet.size - intersectionSize + assert(unionSize > 0, "The union of two input sets must have at least 1 elements") + 1 - intersectionSize / unionSize + } + + @Since("2.1.0") + override protected[ml] def hashDistance(x: Vector, y: Vector): Double = { + // Since it's generated by hashing, it will be a pair of dense vectors. + x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) + + @Since("2.1.0") + override def write: MLWriter = new MinHashModel.MinHashModelWriter(this) +} + +/** + * :: Experimental :: + * + * LSH class for Jaccard distance. + * + * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example, + * `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])` + * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5. 
+ * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated + * as binary "1" values. + * + * References: + * [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]] + */ +@Experimental +@Since("2.1.0") +class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed { + + + @Since("2.1.0") + override def setInputCol(value: String): this.type = super.setInputCol(value) + + @Since("2.1.0") + override def setOutputCol(value: String): this.type = super.setOutputCol(value) + + @Since("2.1.0") + override def setOutputDim(value: Int): this.type = super.setOutputDim(value) + + @Since("2.1.0") + def this() = { + this(Identifiable.randomUID("min hash")) + } + + /** @group setParam */ + @Since("2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.1.0") + override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = { + require(inputDim <= MinHash.prime / 2, + s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.") + val rand = new Random($(seed)) + val numEntry = inputDim * 2 + val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1)) + new MinHashModel(uid, numEntry, randCoofs) + } + + @Since("2.1.0") + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + validateAndTransformSchema(schema) + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) +} + +@Since("2.1.0") +object MinHash extends DefaultParamsReadable[MinHash] { + // A large prime smaller than sqrt(2^63 − 1) + private[ml] val prime = 2038074743 + + @Since("2.1.0") + override def load(path: String): MinHash = super.load(path) +} + +@Since("2.1.0") +object MinHashModel extends MLReadable[MinHashModel] { + + @Since("2.1.0") + override def read: MLReader[MinHashModel] = new MinHashModelReader + + @Since("2.1.0") + override def load(path: String): MinHashModel = super.load(path) + + private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter { + + private case class Data(numEntries: Int, randCoefficients: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numEntries, instance.randCoefficients) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinHashModelReader extends MLReader[MinHashModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[MinHashModel].getName + + override def load(path: String): MinHashModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head() + val numEntries = data.getAs[Int](0) + val randCoefficients = data.getAs[Seq[Int]](1).toArray + val model = new MinHashModel(metadata.uid, numEntries, randCoefficients) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala new file mode 100644 index 000000000000..1b524c6710b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala @@ -0,0 +1,225 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import breeze.linalg.normalize +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * + * Params for [[RandomProjection]]. + */ +private[ml] trait RandomProjectionParams extends Params { + + /** + * The length of each hash bucket, a larger bucket lowers the false negative rate. The number of + * buckets will be `(max L2 norm of input vectors) / bucketLength`. + * + * + * If input vectors are normalized, 1-10 times of pow(numRecords, -1/inputDim) would be a + * reasonable value + * @group param + */ + val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength", + "the length of each hash bucket, a larger bucket lowers the false negative rate.", + ParamValidators.gt(0)) + + /** @group getParam */ + final def getBucketLength: Double = $(bucketLength) +} + +/** + * :: Experimental :: + * + * Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors + * are normalized to be unit vectors and each vector is used in a hash function: + * `h_i(x) = floor(r_i.dot(x) / bucketLength)` + * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input + * vectors) / bucketLength`. + * + * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function. + */ +@Experimental +@Since("2.1.0") +class RandomProjectionModel private[ml] ( + override val uid: String, + @Since("2.1.0") val randUnitVectors: Array[Vector]) + extends LSHModel[RandomProjectionModel] with RandomProjectionParams { + + @Since("2.1.0") + override protected[ml] val hashFunction: (Vector) => Vector = { + key: Vector => { + val hashValues: Array[Double] = randUnitVectors.map({ + randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) + }) + Vectors.dense(hashValues) + } + } + + @Since("2.1.0") + override protected[ml] def keyDistance(x: Vector, y: Vector): Double = { + Math.sqrt(Vectors.sqdist(x, y)) + } + + @Since("2.1.0") + override protected[ml] def hashDistance(x: Vector, y: Vector): Double = { + // Since it's generated by hashing, it will be a pair of dense vectors. 
+ x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) + + @Since("2.1.0") + override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this) +} + +/** + * :: Experimental :: + * + * This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean + * distance metrics. + * + * The input is dense or sparse vectors, each of which represents a point in the Euclidean + * distance space. The output will be vectors of configurable dimension. Hash value in the same + * dimension is calculated by the same hash function. + * + * References: + * + * 1. [[https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions + * Wikipedia on Stable Distributions]] + * + * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint + * arXiv:1408.2927 (2014). + */ +@Experimental +@Since("2.1.0") +class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel] + with RandomProjectionParams with HasSeed { + + @Since("2.1.0") + override def setInputCol(value: String): this.type = super.setInputCol(value) + + @Since("2.1.0") + override def setOutputCol(value: String): this.type = super.setOutputCol(value) + + @Since("2.1.0") + override def setOutputDim(value: Int): this.type = super.setOutputDim(value) + + @Since("2.1.0") + def this() = { + this(Identifiable.randomUID("random projection")) + } + + /** @group setParam */ + @Since("2.1.0") + def setBucketLength(value: Double): this.type = set(bucketLength, value) + + /** @group setParam */ + @Since("2.1.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.1.0") + override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = { + val rand = new Random($(seed)) + val randUnitVectors: Array[Vector] = { + Array.fill($(outputDim)) { + val randArray = Array.fill(inputDim)(rand.nextGaussian()) + Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray))) + } + } + new RandomProjectionModel(uid, randUnitVectors) + } + + @Since("2.1.0") + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + validateAndTransformSchema(schema) + } + + @Since("2.1.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) +} + +@Since("2.1.0") +object RandomProjection extends DefaultParamsReadable[RandomProjection] { + + @Since("2.1.0") + override def load(path: String): RandomProjection = super.load(path) +} + +@Since("2.1.0") +object RandomProjectionModel extends MLReadable[RandomProjectionModel] { + + @Since("2.1.0") + override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader + + @Since("2.1.0") + override def load(path: String): RandomProjectionModel = super.load(path) + + private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel) + extends MLWriter { + + // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved. 
+ private case class Data(randUnitVectors: Matrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val numRows = instance.randUnitVectors.length + require(numRows > 0) + val numCols = instance.randUnitVectors.head.size + val values = instance.randUnitVectors.map(_.toArray).reduce(Array.concat(_, _)) + val randMatrix = Matrices.dense(numRows, numCols, values) + val data = Data(randMatrix) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomProjectionModel].getName + + override def load(path: String): RandomProjectionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors") + .select("randUnitVectors") + .head() + val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala new file mode 100644 index 000000000000..5c025546f332 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DataTypes + +private[ml] object LSHTest { + /** + * For any locality sensitive function h in a metric space, we meed to verify whether + * the following property is satisfied. + * + * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, + * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1 + * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2 + * + * This is called locality sensitive property. This method checks the property on an + * existing dataset and calculate the probabilities. 
+ * (https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Definition) + * + * This method hashes each elements to hash buckets using LSH, and calculate the false positive + * and false negative: + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * + * @param dataset The dataset to verify the locality sensitive hashing property. + * @param lsh The lsh instance to perform the hashing + * @param distFP Distance threshold for false positive + * @param distFN Distance threshold for false negative + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing the false positive and false negative rate + */ + def calculateLSHProperty[T <: LSHModel[T]]( + dataset: Dataset[_], + lsh: LSH[T], + distFP: Double, + distFN: Double): (Double, Double) = { + val model = lsh.fit(dataset) + val inputCol = model.getInputCol + val outputCol = model.getOutputCol + val transformedData = model.transform(dataset) + + SchemaUtils.checkColumnType(transformedData.schema, model.getOutputCol, new VectorUDT) + + // Perform a cross join and label each pair of same_bucket and distance + val pairs = transformedData.as("a").crossJoin(transformedData.as("b")) + val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType) + val sameBucket = udf((x: Vector, y: Vector) => model.hashDistance(x, y) == 0.0, + DataTypes.BooleanType) + val result = pairs + .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol"))) + .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol"))) + + // Compute the probabilities based on the join result + val positive = result.filter(col("same_bucket")) + val negative = result.filter(!col("same_bucket")) + val falsePositiveCount = positive.filter(col("distance") > distFP).count().toDouble + val falseNegativeCount = negative.filter(col("distance") < distFN).count().toDouble + (falsePositiveCount / positive.count(), falseNegativeCount / negative.count()) + } + + /** + * Compute the precision and recall of approximate nearest neighbors + * @param lsh The lsh instance + * @param dataset the dataset to look for the key + * @param key The key to hash for the item + * @param k The maximum number of items closest to the key + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing precision and recall rate + */ + def calculateApproxNearestNeighbors[T <: LSHModel[T]]( + lsh: LSH[T], + dataset: Dataset[_], + key: Vector, + k: Int, + singleProbing: Boolean): (Double, Double) = { + val model = lsh.fit(dataset) + + // Compute expected + val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType) + val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k) + + // Compute actual + val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol") + + assert(actual.schema.sameType(model + .transformSchema(dataset.schema) + .add("distCol", DataTypes.DoubleType)) + ) + + if (!singleProbing) { + assert(actual.count() == k) + } + + // Compute precision and recall + val correctCount = expected.join(actual, model.getInputCol).count().toDouble + (correctCount / actual.count(), correctCount / expected.count()) + } + + /** + * Compute the precision and recall of approximate similarity join + * @param lsh The lsh instance + * @param datasetA One of the datasets to join + * @param datasetB Another dataset to join + * 
@param threshold The threshold for the distance of record pairs + * @tparam T The class type of lsh + * @return A tuple of two doubles, representing precision and recall rate + */ + def calculateApproxSimilarityJoin[T <: LSHModel[T]]( + lsh: LSH[T], + datasetA: Dataset[_], + datasetB: Dataset[_], + threshold: Double): (Double, Double) = { + val model = lsh.fit(datasetA) + val inputCol = model.getInputCol + + // Compute expected + val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType) + val expected = datasetA.as("a").crossJoin(datasetB.as("b")) + .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold) + + // Compute actual + val actual = model.approxSimilarityJoin(datasetA, datasetB, threshold) + + SchemaUtils.checkColumnType(actual.schema, "distCol", DataTypes.DoubleType) + assert(actual.schema.apply("datasetA").dataType + .sameType(model.transformSchema(datasetA.schema))) + assert(actual.schema.apply("datasetB").dataType + .sameType(model.transformSchema(datasetB.schema))) + + // Compute precision and recall + val correctCount = actual.filter(col("distCol") < threshold).count().toDouble + (correctCount / actual.count(), correctCount / expected.count()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala new file mode 100644 index 000000000000..c32ca7d69cf8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val data = { + for (i <- 0 to 95) yield Vectors.sparse(100, (i until i + 5).map((_, 1.0))) + } + dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + } + + test("params") { + ParamsSuite.checkParams(new MinHash) + val model = new MinHashModel("mh", numEntries = 2, randCoefficients = Array(1)) + ParamsSuite.checkParams(model) + } + + test("MinHash: default params") { + val rp = new MinHash + assert(rp.getOutputDim === 1.0) + } + + test("read/write") { + def checkModelData(model: MinHashModel, model2: MinHashModel): Unit = { + assert(model.numEntries === model2.numEntries) + assertResult(model.randCoefficients)(model2.randCoefficients) + } + val mh = new MinHash() + val settings = Map("inputCol" -> "keys", "outputCol" -> "values") + testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + } + + test("hashFunction") { + val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(0, 1, 3)) + val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) + assert(res.equals(Vectors.dense(0.0, 3.0, 4.0))) + } + + test("keyDistance and hashDistance") { + val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(1)) + val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))) + val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0))) + val keyDist = model.keyDistance(v1, v2) + val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2)) + assert(keyDist === 0.5) + assert(hashDist === 3) + } + + test("MinHash: test of LSH property") { + val mh = new MinHash() + .setOutputDim(1) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12344) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, mh, 0.75, 0.5) + assert(falsePositive < 0.3) + assert(falseNegative < 0.3) + } + + test("approxNearestNeighbors for min hash") { + val mh = new MinHash() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12345) + + val key: Vector = Vectors.sparse(100, + (0 until 100).filter(_.toString.contains("1")).map((_, 1.0))) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20, + singleProbing = true) + assert(precision >= 0.7) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for minhash on different dataset") { + val data1 = { + for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0))) + } + val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys") + + val data2 = { + for (i <- 0 until 30) yield Vectors.sparse(100, (3 * i until 3 * i + 3).map((_, 1.0))) + } + val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") + + val mh = new MinHash() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(mh, df1, df2, 0.5) + assert(precision == 1.0) + assert(recall >= 0.7) + } +} diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala new file mode 100644 index 000000000000..cd82ee2117a0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import breeze.numerics.{cos, sin} +import breeze.numerics.constants.Pi + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +class RandomProjectionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val data = { + for (i <- -10 until 10; j <- -10 until 10) yield Vectors.dense(i.toDouble, j.toDouble) + } + dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + } + + test("params") { + ParamsSuite.checkParams(new RandomProjection) + val model = new RandomProjectionModel("rp", randUnitVectors = Array(Vectors.dense(1.0, 0.0))) + ParamsSuite.checkParams(model) + } + + test("RandomProjection: default params") { + val rp = new RandomProjection + assert(rp.getOutputDim === 1.0) + } + + test("read/write") { + def checkModelData(model: RandomProjectionModel, model2: RandomProjectionModel): Unit = { + model.randUnitVectors.zip(model2.randUnitVectors) + .foreach(pair => assert(pair._1 === pair._2)) + } + val mh = new RandomProjection() + val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0) + testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + } + + test("hashFunction") { + val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0)) + val model = new RandomProjectionModel("rp", randUnitVectors) + model.set(model.bucketLength, 0.5) + val res = model.hashFunction(Vectors.dense(1.23, 4.56)) + assert(res.equals(Vectors.dense(9.0, 2.0))) + } + + test("keyDistance and hashDistance") { + val model = new RandomProjectionModel("rp", Array(Vectors.dense(0.0, 1.0))) + val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2)) + val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2)) + assert(keyDist === 5) + assert(hashDist === 3) + } + + test("RandomProjection: randUnitVectors") { + val rp = new RandomProjection() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + 
.setSeed(12345) + val unitVectors = rp.fit(dataset).randUnitVectors + unitVectors.foreach { v: Vector => + assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) + } + } + + test("RandomProjection: test of LSH property") { + // Project from 2 dimensional Euclidean Space to 1 dimensions + val rp = new RandomProjection() + .setOutputDim(1) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, rp, 8.0, 2.0) + assert(falsePositive < 0.4) + assert(falseNegative < 0.4) + } + + test("RandomProjection with high dimension data: test of LSH property") { + val numDim = 100 + val data = { + for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2)) + yield Vectors.sparse(numDim, Seq((i, j.toDouble))) + } + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + + // Project from 100 dimensional Euclidean Space to 10 dimensions + val rp = new RandomProjection() + .setOutputDim(10) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(2.5) + .setSeed(12345) + + val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, rp, 3.0, 2.0) + assert(falsePositive < 0.3) + assert(falseNegative < 0.3) + } + + test("approxNearestNeighbors for random projection") { + val key = Vectors.dense(1.2, 3.4) + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100, + singleProbing = true) + assert(precision >= 0.6) + assert(recall >= 0.6) + } + + test("approxNearestNeighbors with multiple probing") { + val key = Vectors.dense(1.2, 3.4) + + val rp = new RandomProjection() + .setOutputDim(20) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100, + singleProbing = false) + assert(precision >= 0.7) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for random projection on different dataset") { + val data2 = { + for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i)) + } + val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys") + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, dataset, dataset2, 1.0) + assert(precision == 1.0) + assert(recall >= 0.7) + } + + test("approxSimilarityJoin for self join") { + val data = { + for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i)) + } + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys") + + val rp = new RandomProjection() + .setOutputDim(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(4.0) + .setSeed(12345) + + val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, df, df, 3.0) + assert(precision == 1.0) + assert(recall >= 0.7) + } +} From 59cccbda489f25add3e10997e950de7e88704aa7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 28 Oct 2016 20:14:38 -0700 Subject: [PATCH 047/381] [SPARK-18164][SQL] ForeachSink should fail the Spark job if `process` throws exception ## What changes were proposed in this pull request? Fixed the issue that ForeachSink didn't rethrow the exception. ## How was this patch tested? The fixed unit test. 
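For illustration only (not part of the diff below): a sketch of the user-facing behavior this fix targets, assuming a socket source and a `SparkSession` named `spark`; the host, port and error message are placeholders.

```scala
// Sketch: after this fix, an exception thrown from ForeachWriter.process is rethrown,
// so the streaming query fails instead of silently dropping the batch.
// The socket source, host and port below are assumed for the example.
import org.apache.spark.sql.{ForeachWriter, Row}

val query = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()
  .writeStream
  .foreach(new ForeachWriter[Row] {
    def open(partitionId: Long, version: Long): Boolean = true
    def process(value: Row): Unit =
      throw new RuntimeException("error")        // now fails the Spark job ...
    def close(errorOrNull: Throwable): Unit = () // ... and close() still receives the error
  })
  .start()

// query.awaitTermination() now surfaces the failure as a StreamingQueryException
// whose cause chain contains the original error, matching the updated test below.
```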
Author: Shixiong Zhu Closes #15674 from zsxwing/foreach-sink-error. --- .../sql/execution/streaming/ForeachSink.scala | 7 ++----- .../streaming/ForeachSinkSuite.scala | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 082664aa23f0..24f98b9211f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -68,19 +68,16 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria } datasetWithIncrementalExecution.foreachPartition { iter => if (writer.open(TaskContext.getPartitionId(), batchId)) { - var isFailed = false try { while (iter.hasNext) { writer.process(iter.next()) } } catch { case e: Throwable => - isFailed = true writer.close(e) + throw e } - if (!isFailed) { - writer.close(null) - } + writer.close(null) } else { writer.close(null) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 7928b8e8775c..9e059216110f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -23,8 +23,9 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter -import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { @@ -136,7 +137,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } } - test("foreach with error") { + testQuietly("foreach with error") { withTempDir { checkpointDir => val input = MemoryStream[Int] val query = input.toDS().repartition(1).writeStream @@ -148,16 +149,24 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } }).start() input.addData(1, 2, 3, 4) - query.processAllAvailable() + + // Error in `process` should fail the Spark job + val e = intercept[StreamingQueryException] { + query.processAllAvailable() + } + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.getMessage === "error") + assert(query.isActive === false) val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + + // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) assert(errorEvent.error.get.getMessage === "error") - query.stop() } } } From d2d438d1d549628a0183e468ed11d6e85b5d6061 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 29 Oct 2016 06:49:57 +0200 Subject: [PATCH 048/381] [SPARK-18167][SQL] Add debug code for SQLQuerySuite flakiness when metastore partition pruning is enabled ## What changes were proposed in this 
pull request? org.apache.spark.sql.hive.execution.SQLQuerySuite is flaking when hive partition pruning is enabled. Based on the stack traces, it seems to be an old issue where Hive fails to cast a numeric partition column ("Invalid character string format for type DECIMAL"). There are two possibilities here: either we are somehow corrupting the partition table to have non-decimal values in that column, or there is a transient issue with Derby. This PR logs the result of the retry when this exception is encountered, so we can confirm what is going on. ## How was this patch tested? n/a cc yhuai Author: Eric Liang Closes #15676 from ericl/spark-18167. --- .../apache/spark/sql/hive/client/HiveShim.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 32387707612f..4bbbd66132b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -585,7 +586,19 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] } else { logDebug(s"Hive metastore filter is '$filter'.") - getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] + try { + getPartitionsByFilterMethod.invoke(hive, table, filter) + .asInstanceOf[JArrayList[Partition]] + } catch { + case e: InvocationTargetException => + // SPARK-18167 retry to investigate the flaky test. This should be reverted before + // the release is cut. + val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter)) + val full = Try(getAllPartitionsMethod.invoke(hive, table)) + logError("getPartitionsByFilter failed, retry success = " + retry.isSuccess) + logError("getPartitionsByFilter failed, full fetch success = " + full.isSuccess) + throw e + } } partitions.asScala.toSeq From 505b927cb7ff037adb797b9c3b9ecac3f885b7c8 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Sun, 30 Oct 2016 09:32:19 +0000 Subject: [PATCH 049/381] [SPARK-16312][FOLLOW-UP][STREAMING][KAFKA][DOC] Add java code snippet for Kafka 0.10 integration doc ## What changes were proposed in this pull request? added java code snippet for Kafka 0.10 integration doc ## How was this patch tested? SKIP_API=1 jekyll build ## Screenshot ![kafka-doc](https://cloud.githubusercontent.com/assets/15843379/19826272/bf0d8a4c-9db8-11e6-9e40-1396723df4bc.png) Author: Liwei Lin Closes #15679 from lw-lin/kafka-010-examples. 
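For readers following along in Scala: the Java snippets added below mirror the Scala examples already on this page. As a reference point, a sketch of the corresponding Scala direct-stream setup (assuming a running `StreamingContext` named `streamingContext`; broker addresses, group id and topics are placeholders) looks like this:

```scala
// Scala counterpart of the "Creating a Direct Stream" Java snippet added below.
// streamingContext, broker addresses, group id and topic names are assumed for the example.
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.streaming.kafka010._
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe

val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092,anotherhost:9092",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "use_a_separate_group_id_for_each_stream",
  "auto.offset.reset" -> "latest",
  "enable.auto.commit" -> (false: java.lang.Boolean)
)

val topics = Array("topicA", "topicB")
val stream = KafkaUtils.createDirectStream[String, String](
  streamingContext,
  PreferConsistent,
  Subscribe[String, String](topics, kafkaParams)
)

// Each record is a ConsumerRecord[String, String]; extract key/value pairs.
stream.map(record => (record.key, record.value))
```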
--- docs/streaming-kafka-0-10-integration.md | 133 +++++++++++++++++++++-- 1 file changed, 122 insertions(+), 11 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index de95ea90137e..c1ef396907db 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -8,9 +8,9 @@ The Spark Streaming integration for Kafka 0.10 is similar in design to the 0.8 [ ### Linking For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - groupId = org.apache.spark - artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 @@ -44,6 +44,42 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
+ import java.util.*; + import org.apache.spark.SparkConf; + import org.apache.spark.TaskContext; + import org.apache.spark.api.java.*; + import org.apache.spark.api.java.function.*; + import org.apache.spark.streaming.api.java.*; + import org.apache.spark.streaming.kafka010.*; + import org.apache.kafka.clients.consumer.ConsumerRecord; + import org.apache.kafka.common.TopicPartition; + import org.apache.kafka.common.serialization.StringDeserializer; + import scala.Tuple2; + + Map kafkaParams = new HashMap<>(); + kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); + kafkaParams.put("key.deserializer", StringDeserializer.class); + kafkaParams.put("value.deserializer", StringDeserializer.class); + kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); + kafkaParams.put("auto.offset.reset", "latest"); + kafkaParams.put("enable.auto.commit", false); + + Collection topics = Arrays.asList("topicA", "topicB"); + + final JavaInputDStream> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(topics, kafkaParams) + ); + + stream.mapToPair( + new PairFunction, String, String>() { + @Override + public Tuple2 call(ConsumerRecord record) { + return new Tuple2<>(record.key(), record.value()); + } + })
@@ -85,6 +121,20 @@ If you have a use case that is better suited to batch processing, you can create
+ // Import dependencies and create kafka params as in Create Direct Stream above + + OffsetRange[] offsetRanges = { + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange.create("test", 0, 0, 100), + OffsetRange.create("test", 1, 0, 100) + }; + + JavaRDD> rdd = KafkaUtils.createRDD( + sparkContext, + kafkaParams, + offsetRanges, + LocationStrategies.PreferConsistent() + );
@@ -103,6 +153,20 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no }
+ stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + rdd.foreachPartition(new VoidFunction>>() { + @Override + public void call(Iterator> consumerRecords) { + OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); + } + }); + } + });
@@ -120,15 +184,24 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
stream.foreachRDD { rdd => - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets) + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) } As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
+ stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + // some time later, after outputs have completed + ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); + } + });
@@ -141,7 +214,7 @@ For data stores that support transactions, saving offsets in the same transactio // begin from the the offsets committed to the database val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => - new TopicPartition(resultSet.string("topic")), resultSet.int("partition")) -> resultSet.long("offset") + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") }.toMap val stream = KafkaUtils.createDirectStream[String, String]( @@ -155,16 +228,46 @@ For data stores that support transactions, saving offsets in the same transactio val results = yourCalculation(rdd) - yourTransactionBlock { - // update results + // begin your transaction - // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly - // assert that offsets were updated correctly - } + // end your transaction }
+ // The details depend on your data store, but the general idea looks like this + + // begin from the the offsets committed to the database + Map fromOffsets = new HashMap<>(); + for (resultSet : selectOffsetsFromYourDatabase) + fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); + } + + JavaInputDStream> stream = KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) + ); + + stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + Object results = yourCalculation(rdd); + + // begin your transaction + + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly + + // end your transaction + } + });
@@ -185,6 +288,14 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html )
+ Map<String, Object> kafkaParams = new HashMap<String, Object>();
+ // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
+ kafkaParams.put("security.protocol", "SSL");
+ kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
+ kafkaParams.put("ssl.truststore.password", "test1234");
+ kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
+ kafkaParams.put("ssl.keystore.password", "test1234");
+ kafkaParams.put("ssl.key.password", "test1234");
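For completeness, a Scala sketch of plugging TLS parameters like these into `createDirectStream`; the broker address, topic name, and group id below are placeholders rather than values from the patch:

```scala
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.streaming.kafka010._

val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "broker1:9093",            // a TLS listener, not the plaintext 9092
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "example-group",
  "security.protocol" -> "SSL",
  "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks",
  "ssl.truststore.password" -> "test1234",
  "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks",
  "ssl.keystore.password" -> "test1234",
  "ssl.key.password" -> "test1234"
)

val stream = KafkaUtils.createDirectStream[String, String](
  streamingContext,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](Seq("topicA"), kafkaParams)
)
```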
From a489567e36e671cee290f8d69188837a8b1a75b3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 30 Oct 2016 09:36:23 +0000 Subject: [PATCH 050/381] [SPARK-3261][MLLIB] KMeans clusterer can return duplicate cluster centers ## What changes were proposed in this pull request? Return potentially fewer than k cluster centers in cases where k distinct centroids aren't available or aren't selected. ## How was this patch tested? Existing tests Author: Sean Owen Closes #15450 from srowen/SPARK-3261. --- .../apache/spark/ml/clustering/KMeans.scala | 4 +- .../spark/mllib/clustering/KMeans.scala | 27 ++-- .../spark/mllib/clustering/KMeansSuite.scala | 119 ++++++++++-------- 3 files changed, 85 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 05ed3223ae53..85bb8c93b3fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -41,7 +41,9 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe with HasSeed with HasPredictionCol with HasTol { /** - * The number of clusters to create (k). Must be > 1. Default: 2. + * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than + * k clusters to be returned, for example, if there are fewer than k distinct points to cluster. + * Default: 2. * @group param */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 68a7b3b6763a..ed9c064879d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -56,13 +56,15 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). + * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Default: 2. + * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") def setK(k: Int): this.type = { @@ -323,7 +325,10 @@ class KMeans private ( * Initialize a set of cluster centers at random. */ private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { - data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense) + // Select without replacement; may still produce duplicates if the data has < k distinct + // points, so deduplicate the centroids to match the behavior of k-means|| in the same situation + data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt()) + .map(_.vector).distinct.map(new VectorWithNorm(_)) } /** @@ -335,7 +340,7 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { // Initialize empty centers and point costs. 
var costs = data.map(_ => Double.PositiveInfinity) @@ -378,19 +383,21 @@ class KMeans private ( costs.unpersist(blocking = false) bcNewCentersList.foreach(_.destroy(false)) - if (centers.size == k) { - centers.toArray + val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_)) + + if (distinctCenters.size <= k) { + distinctCenters.toArray } else { - // Finally, we might have a set of more or less than k candidate centers; weight each + // Finally, we might have a set of more than k distinct candidate centers; weight each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them - val bcCenters = data.context.broadcast(centers) + val bcCenters = data.context.broadcast(distinctCenters) val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() bcCenters.destroy(blocking = false) - val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray - LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30) + val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray + LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 2d35b312083c..48bd41dc3e3b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -29,6 +29,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} + private val seed = 42 + test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(1.0, 2.0, 6.0), @@ -38,7 +40,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 3.0, 4.0) - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) @@ -50,44 +52,72 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( - data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) + data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) } - test("no distinct points") { + test("fewer distinct points than clusters") { val data = sc.parallelize( Array( Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(1.0, 2.0, 3.0)), 2) - val center = Vectors.dense(1.0, 2.0, 3.0) - // Make sure code runs. 
- var model = KMeans.train(data, k = 2, maxIterations = 1) - assert(model.clusterCenters.size === 2) - } + var model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "random") + assert(model.clusterCenters.length === 1) - test("more clusters than points") { - val data = sc.parallelize( - Array( - Vectors.dense(1.0, 2.0, 3.0), - Vectors.dense(1.0, 3.0, 4.0)), - 2) + model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "k-means||") + assert(model.clusterCenters.length === 1) + } - // Make sure code runs. - var model = KMeans.train(data, k = 3, maxIterations = 1) - assert(model.clusterCenters.size === 3) + test("unique cluster centers") { + val rng = new Random(seed) + val numDistinctPoints = 10 + val points = (0 until numDistinctPoints).map(i => Vectors.dense(Array.fill(3)(rng.nextDouble))) + val data = sc.parallelize(points.flatMap(Array.fill(1 + rng.nextInt(3))(_)), 2) + val normedData = data.map(new VectorWithNorm(_)) + + // less centers than k + val km = new KMeans().setK(50) + .setMaxIterations(5) + .setInitializationMode("k-means||") + .setInitializationSteps(10) + .setSeed(seed) + val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + assert(initialCenters.length === initialCenters.distinct.length) + assert(initialCenters.length <= numDistinctPoints) + + val model = km.run(data) + val finalCenters = model.clusterCenters + assert(finalCenters.length === finalCenters.distinct.length) + + // run local k-means + val k = 10 + val km2 = new KMeans().setK(k) + .setMaxIterations(5) + .setInitializationMode("k-means||") + .setInitializationSteps(10) + .setSeed(seed) + val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + assert(initialCenters2.length === initialCenters2.distinct.length) + assert(initialCenters2.length === k) + + val model2 = km2.run(data) + val finalCenters2 = model2.clusterCenters + assert(finalCenters2.length === finalCenters2.distinct.length) + + val km3 = new KMeans().setK(k) + .setMaxIterations(5) + .setInitializationMode("random") + .setSeed(seed) + val model3 = km3.run(data) + val finalCenters3 = model3.clusterCenters + assert(finalCenters3.length === finalCenters3.distinct.length) } test("deterministic initialization") { @@ -97,12 +127,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { // Create three deterministic models and compare cluster means - val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, - initializationMode = initMode, seed = 42) + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, + initializationMode = initMode, seed = seed) val centers1 = model1.clusterCenters - val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, - initializationMode = initMode, seed = 42) + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, + initializationMode = initMode, seed = seed) val centers2 = model2.clusterCenters centers1.zip(centers2).foreach { case (c1, c2) => @@ -119,7 +149,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { ) val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points val center = Vectors.dense(1.0, 3.0, 4.0) @@ -134,17 +164,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) 
assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, - initializationMode = K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) } @@ -165,7 +188,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { data.persist() - // No matter how many runs or iterations we use, we should get one cluster, + // No matter how many iterations we use, we should get one cluster, // centered at the mean of the points val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) @@ -179,17 +202,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(data, k = 1, maxIterations = 5) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM) assert(model.clusterCenters.head ~== center absTol 1E-5) - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head ~== center absTol 1E-5) - - model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, - initializationMode = K_MEANS_PARALLEL) + model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL) assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() @@ -230,11 +246,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { model = KMeans.train(rdd, k = 5, maxIterations = 10) assert(model.clusterCenters.sortBy(VectorWithCompare(_)) .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) - - // Neither should more runs - model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(model.clusterCenters.sortBy(VectorWithCompare(_)) - .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { @@ -250,7 +261,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { // Two iterations are sufficient no matter where the initial centers are. - val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode) + val model = KMeans.train(rdd, k = 2, maxIterations = 2, initMode) val predicts = model.predict(rdd).collect() From 3ad99f166494950665c137fd5dea636afa0feb10 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 30 Oct 2016 20:27:38 +0800 Subject: [PATCH 051/381] [SPARK-18146][SQL] Avoid using Union to chain together create table and repair partition commands ## What changes were proposed in this pull request? The behavior of union is not well defined here. It is safer to explicitly execute these commands in order. 
The other use of `Union` in this way will be removed by https://github.com/apache/spark/pull/15633 ## How was this patch tested? Existing tests. cc yhuai cloud-fan Author: Eric Liang Author: Eric Liang Closes #15665 from ericl/spark-18146. --- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7ff3522f547d..11dd1df90993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -388,16 +388,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec ) - val createCmd = CreateTable(tableDesc, mode, Some(df.logicalPlan)) - val cmd = if (tableDesc.partitionColumnNames.nonEmpty && + df.sparkSession.sessionState.executePlan( + CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd + if (tableDesc.partitionColumnNames.nonEmpty && df.sparkSession.sqlContext.conf.manageFilesourcePartitions) { // Need to recover partitions into the metastore so our saved data is visible. - val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(tableDesc.identifier) - Union(createCmd, recoverPartitionCmd) - } else { - createCmd + df.sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(tableDesc.identifier)).toRdd } - df.sparkSession.sessionState.executePlan(cmd).toRdd } } From 90d3b91f4cb59d84fea7105d54ef8c87a7d5c6a2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 30 Oct 2016 13:14:45 -0700 Subject: [PATCH 052/381] [SPARK-18103][SQL] Rename *FileCatalog to *FileIndex ## What changes were proposed in this pull request? To reduce the number of components in SQL named *Catalog, rename *FileCatalog to *FileIndex. A FileIndex is responsible for returning the list of partitions / files to scan given a filtering expression. ``` TableFileCatalog => CatalogFileIndex FileCatalog => FileIndex ListingFileCatalog => InMemoryFileIndex MetadataLogFileCatalog => MetadataLogFileIndex PrunedTableFileCatalog => PrunedInMemoryFileIndex ``` cc yhuai marmbrus ## How was this patch tested? N/A Author: Eric Liang Author: Eric Liang Closes #15634 from ericl/rename-file-provider. 
--- .../spark/metrics/source/StaticSources.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 2 +- ...leCatalog.scala => CatalogFileIndex.scala} | 24 ++++++------- .../execution/datasources/DataSource.scala | 10 +++--- .../{FileCatalog.scala => FileIndex.scala} | 2 +- .../datasources/HadoopFsRelation.scala | 4 +-- ...eCatalog.scala => InMemoryFileIndex.scala} | 8 ++--- ...scala => PartitioningAwareFileIndex.scala} | 16 ++++----- .../PruneFileSourcePartitions.scala | 6 ++-- .../streaming/CompactibleFileStreamLog.scala | 4 +-- .../streaming/FileStreamSource.scala | 4 +-- .../streaming/MetadataLogFileCatalog.scala | 6 ++-- .../datasources/FileCatalogSuite.scala | 36 +++++++++---------- .../datasources/FileSourceStrategySuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../sql/streaming/FileStreamSinkSuite.scala | 6 ++-- .../sql/streaming/FileStreamSourceSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +-- .../spark/sql/hive/CachedTableSuite.scala | 10 +++--- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- .../PruneFileSourcePartitionsSuite.scala | 6 ++-- 21 files changed, 79 insertions(+), 79 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{TableFileCatalog.scala => CatalogFileIndex.scala} (83%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{FileCatalog.scala => FileIndex.scala} (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{ListingFileCatalog.scala => InMemoryFileIndex.scala} (92%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{PartitioningAwareFileCatalog.scala => PartitioningAwareFileIndex.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index b54885b7ff8b..3f7cfd9d2c11 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -76,7 +76,7 @@ object HiveCatalogMetrics extends Source { val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched")) /** - * Tracks the total number of files discovered off of the filesystem by ListingFileCatalog. + * Tracks the total number of files discovered off of the filesystem by InMemoryFileIndex. */ val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index fb72c679e362..526623a36d2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -177,7 +177,7 @@ class CacheManager extends Logging { /** * Traverses a given `plan` and searches for the occurrences of `qualifiedPath` in the - * [[org.apache.spark.sql.execution.datasources.FileCatalog]] of any [[HadoopFsRelation]] nodes + * [[org.apache.spark.sql.execution.datasources.FileIndex]] of any [[HadoopFsRelation]] nodes * in the plan. If found, we refresh the metadata and return true. Otherwise, this method returns * false. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index b459df5734d4..092aabc89a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -26,23 +26,23 @@ import org.apache.spark.sql.types.StructType /** - * A [[FileCatalog]] for a metastore catalog table. + * A [[FileIndex]] for a metastore catalog table. * * @param sparkSession a [[SparkSession]] * @param table the metadata of the table * @param sizeInBytes the table's data size in bytes */ -class TableFileCatalog( +class CatalogFileIndex( sparkSession: SparkSession, val table: CatalogTable, - override val sizeInBytes: Long) extends FileCatalog { + override val sizeInBytes: Long) extends FileIndex { protected val hadoopConf = sparkSession.sessionState.newHadoopConf private val fileStatusCache = FileStatusCache.newCache(sparkSession) assert(table.identifier.database.isDefined, - "The table identifier must be qualified in TableFileCatalog") + "The table identifier must be qualified in CatalogFileIndex") private val baseLocation = table.storage.locationUri @@ -57,12 +57,12 @@ class TableFileCatalog( override def refresh(): Unit = fileStatusCache.invalidateAll() /** - * Returns a [[ListingFileCatalog]] for this table restricted to the subset of partitions + * Returns a [[InMemoryFileIndex]] for this table restricted to the subset of partitions * specified by the given partition-pruning filters. * * @param filters partition-pruning filters */ - def filterPartitions(filters: Seq[Expression]): ListingFileCatalog = { + def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { if (table.partitionColumnNames.nonEmpty) { val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) @@ -70,20 +70,20 @@ class TableFileCatalog( PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) } val partitionSpec = PartitionSpec(partitionSchema, partitions) - new PrunedTableFileCatalog( + new PrunedInMemoryFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) } else { - new ListingFileCatalog(sparkSession, rootPaths, table.storage.properties, None) + new InMemoryFileIndex(sparkSession, rootPaths, table.storage.properties, None) } } override def inputFiles: Array[String] = filterPartitions(Nil).inputFiles - // `TableFileCatalog` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member + // `CatalogFileIndex` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member // of `LogicalRelation`, and `LogicalRelation` may be used as the cache key. So we need to // implement `equals` and `hashCode` here, to make it work with cache lookup. 
override def equals(o: Any): Boolean = o match { - case other: TableFileCatalog => this.table.identifier == other.table.identifier + case other: CatalogFileIndex => this.table.identifier == other.table.identifier case _ => false } @@ -97,12 +97,12 @@ class TableFileCatalog( * @param tableBasePath The default base path of the Hive metastore table * @param partitionSpec The partition specifications from Hive metastore */ -private class PrunedTableFileCatalog( +private class PrunedInMemoryFileIndex( sparkSession: SparkSession, tableBasePath: Path, fileStatusCache: FileStatusCache, override val partitionSpec: PartitionSpec) - extends ListingFileCatalog( + extends InMemoryFileIndex( sparkSession, partitionSpec.partitions.map(_.path), Map.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5b8f05a39624..996109865fdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -202,7 +202,7 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) + val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None) val partitionSchema = fileCatalog.partitionSpec().partitionColumns val inferred = format.inferSchema( sparkSession, @@ -364,7 +364,7 @@ case class DataSource( case (format: FileFormat, _) if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val fileCatalog = new MetadataLogFileCatalog(sparkSession, basePath) + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -417,12 +417,12 @@ case class DataSource( val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.partitionProviderIsHive) { - new TableFileCatalog( + new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new ListingFileCatalog( + new InMemoryFileIndex( sparkSession, globbedPaths, options, partitionSchema) } @@ -433,7 +433,7 @@ case class DataSource( format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.asInstanceOf[ListingFileCatalog].allFiles()) + fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles()) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. 
" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index dba64624c34b..277223d52ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -33,7 +33,7 @@ case class PartitionDirectory(values: InternalRow, files: Seq[FileStatus]) * An interface for objects capable of enumerating the root paths of a relation as well as the * partitions of a relation subject to some pruning expressions. */ -trait FileCatalog { +trait FileIndex { /** * Returns the list of root input paths from which the catalog will get files. There may be a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index afad8898089b..014abd454f5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType * Acts as a container for all of the metadata required to read from a datasource. All discovery, * resolution and merging logic for schemas and partitions has been removed. * - * @param location A [[FileCatalog]] that can enumerate the locations of all the files that + * @param location A [[FileIndex]] that can enumerate the locations of all the files that * comprise this relation. * @param partitionSchema The schema of the columns (if any) that are used to partition the relation * @param dataSchema The schema of any remaining columns. Note that if any partition columns are @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType * @param options Configuration used when reading / writing data. */ case class HadoopFsRelation( - location: FileCatalog, + location: FileIndex, partitionSchema: StructType, dataSchema: StructType, bucketSpec: Option[BucketSpec], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index d9d588388aaf..7531f0ae02e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.StructType /** - * A [[FileCatalog]] that generates the list of files to process by recursively listing all the + * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. 
* * @param rootPaths the list of root table paths to scan @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.StructType * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ -class ListingFileCatalog( +class InMemoryFileIndex( sparkSession: SparkSession, override val rootPaths: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) - extends PartitioningAwareFileCatalog( + extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @@ -79,7 +79,7 @@ class ListingFileCatalog( } override def equals(other: Any): Boolean = other match { - case hdfs: ListingFileCatalog => rootPaths.toSet == hdfs.rootPaths.toSet + case hdfs: InMemoryFileIndex => rootPaths.toSet == hdfs.rootPaths.toSet case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index cc4049e92590..a8a722dd3c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,19 +34,19 @@ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** - * An abstract class that represents [[FileCatalog]]s that are aware of partitioned tables. + * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery * @param userPartitionSchema an optional partition schema that will be use to provide types for * the discovered partitions */ -abstract class PartitioningAwareFileCatalog( +abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], userPartitionSchema: Option[StructType], - fileStatusCache: FileStatusCache = NoopCache) extends FileCatalog with Logging { - import PartitioningAwareFileCatalog.BASE_PATH_PARAM + fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { + import PartitioningAwareFileIndex.BASE_PATH_PARAM /** Returns the specification of the partitions inferred from the data. 
*/ def partitionSpec(): PartitionSpec @@ -253,9 +253,9 @@ abstract class PartitioningAwareFileCatalog( } val discovered = if (pathsToFetch.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - PartitioningAwareFileCatalog.listLeafFilesInParallel(pathsToFetch, hadoopConf, sparkSession) + PartitioningAwareFileIndex.listLeafFilesInParallel(pathsToFetch, hadoopConf, sparkSession) } else { - PartitioningAwareFileCatalog.listLeafFilesInSerial(pathsToFetch, hadoopConf) + PartitioningAwareFileIndex.listLeafFilesInSerial(pathsToFetch, hadoopConf) } discovered.foreach { case (path, leafFiles) => HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) @@ -266,7 +266,7 @@ abstract class PartitioningAwareFileCatalog( } } -object PartitioningAwareFileCatalog extends Logging { +object PartitioningAwareFileIndex extends Logging { val BASE_PATH_PARAM = "basePath" /** A serializable variant of HDFS's BlockLocation. */ @@ -383,7 +383,7 @@ object PartitioningAwareFileCatalog extends Logging { if (shouldFilterOut(name)) { Seq.empty[FileStatus] } else { - // [SPARK-17599] Prevent ListingFileCatalog from failing if path doesn't exist + // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist // Note that statuses only include FileStatus for the files and dirs directly under path, // and does not include anything else recursively. val statuses = try fs.listStatus(path) catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8689017c3ed7..8566a8061034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -28,7 +28,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { logicalRelation @ LogicalRelation(fsRelation @ HadoopFsRelation( - tableFileCatalog: TableFileCatalog, + catalogFileIndex: CatalogFileIndex, partitionSchema, _, _, @@ -56,9 +56,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { - val prunedFileCatalog = tableFileCatalog.filterPartitions(partitionKeyFilters.toSeq) + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = - fsRelation.copy(location = prunedFileCatalog)(sparkSession) + fsRelation.copy(location = prunedFileIndex)(sparkSession) val prunedLogicalRelation = logicalRelation.copy( relation = prunedFsRelation, expectedOutputAttributes = Some(logicalRelation.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index c14feea91ed7..b26edeeb0400 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -146,7 +146,7 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( */ def allFiles(): Array[T] = { var latestId = getLatest().map(_._1).getOrElse(-1L) - // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` + // There is a race condition when 
`FileStreamSink` is deleting old files and `StreamFileIndex` // is calling this method. This loop will retry the reading to deal with the // race condition. while (true) { @@ -158,7 +158,7 @@ abstract class CompactibleFileStreamLog[T: ClassTag]( } catch { case e: IOException => // Another process using `CompactibleFileStreamLog` may delete the batch files when - // `StreamFileCatalog` are reading. However, it only happens when a compaction is + // `StreamFileIndex` are reading. However, it only happens when a compaction is // deleting old files. If so, let's try the next compaction batch and we should find it. // Otherwise, this is a real IO issue and we should throw it. latestId = nextCompactionBatchId(latestId, compactInterval) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index a392b8299902..680df01acc1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.execution.datasources.{DataSource, ListingFileCatalog, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.types.StructType /** @@ -156,7 +156,7 @@ class FileStreamSource( private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) - val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType)) + val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => (status.getPath.toUri.toString, status.getModificationTime) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala index 82b67cb1ca6e..aeaa13473693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.execution.datasources._ /** - * A [[FileCatalog]] that generates the list of files to processing by reading them from the + * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. 
*/ -class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) - extends PartitioningAwareFileCatalog(sparkSession, Map.empty, None) { +class MetadataLogFileIndex(sparkSession: SparkSession, path: Path) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala index 9c43169cbf89..56df1face636 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -28,15 +28,15 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.SharedSQLContext -class FileCatalogSuite extends SharedSQLContext { +class FileIndexSuite extends SharedSQLContext { - test("ListingFileCatalog: leaf files are qualified paths") { + test("InMemoryFileIndex: leaf files are qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") stringToFile(file, "text") val path = new Path(file.getCanonicalPath) - val catalog = new ListingFileCatalog(spark, Seq(path), Map.empty, None) { + val catalog = new InMemoryFileIndex(spark, Seq(path), Map.empty, None) { def leafFilePaths: Seq[Path] = leafFiles.keys.toSeq def leafDirPaths: Seq[Path] = leafDirToChildrenFiles.keys.toSeq } @@ -45,7 +45,7 @@ class FileCatalogSuite extends SharedSQLContext { } } - test("ListingFileCatalog: input paths are converted to qualified paths") { + test("InMemoryFileIndex: input paths are converted to qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") stringToFile(file, "text") @@ -59,42 +59,42 @@ class FileCatalogSuite extends SharedSQLContext { val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) - val catalog1 = new ListingFileCatalog( + val catalog1 = new InMemoryFileIndex( spark, Seq(unqualifiedDirPath), Map.empty, None) assert(catalog1.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) - val catalog2 = new ListingFileCatalog( + val catalog2 = new InMemoryFileIndex( spark, Seq(unqualifiedFilePath), Map.empty, None) assert(catalog2.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) } } - test("ListingFileCatalog: folders that don't exist don't throw exceptions") { + test("InMemoryFileIndex: folders that don't exist don't throw exceptions") { withTempDir { dir => val deletedFolder = new File(dir, "deleted") assert(!deletedFolder.exists()) - val catalog1 = new ListingFileCatalog( + val catalog1 = new InMemoryFileIndex( spark, Seq(new Path(deletedFolder.getCanonicalPath)), Map.empty, None) // doesn't throw an exception assert(catalog1.listLeafFiles(catalog1.rootPaths).isEmpty) } } - test("PartitioningAwareFileCatalog - file filtering") { - assert(!PartitioningAwareFileCatalog.shouldFilterOut("abcd")) - assert(PartitioningAwareFileCatalog.shouldFilterOut(".ab")) - assert(PartitioningAwareFileCatalog.shouldFilterOut("_cd")) - assert(!PartitioningAwareFileCatalog.shouldFilterOut("_metadata")) - assert(!PartitioningAwareFileCatalog.shouldFilterOut("_common_metadata")) - assert(PartitioningAwareFileCatalog.shouldFilterOut("_ab_metadata")) - 
assert(PartitioningAwareFileCatalog.shouldFilterOut("_cd_common_metadata")) + test("PartitioningAwareFileIndex - file filtering") { + assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd")) + assert(PartitioningAwareFileIndex.shouldFilterOut(".ab")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_cd")) + assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata")) + assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata")) + assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata")) } - test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { + test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") { class MockCatalog( override val rootPaths: Seq[Path]) - extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + extends PartitioningAwareFileIndex(spark, Map.empty, None) { override def refresh(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c32254d9dfde..d900ce7bb237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -393,7 +393,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi util.stringToFile(file, fileName) } - val fileCatalog = new ListingFileCatalog( + val fileCatalog = new InMemoryFileIndex( sparkSession = spark, rootPaths = Seq(new Path(tempDir)), parameters = Map.empty[String, String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index f2a209e91962..120a3a2ef33a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -634,7 +634,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation( - HadoopFsRelation(location: PartitioningAwareFileCatalog, _, _, _, _, _), _, _) => + HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _) => assert(location.partitionSpec() === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a matching HadoopFsRelation, but got:\n$queryExecution") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 19c89f5c4100..18b42a81a098 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileCatalog} +import 
org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileIndex} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -179,14 +179,14 @@ class FileStreamSinkSuite extends StreamTest { .add(StructField("id", IntegerType)) assert(outputDf.schema === expectedSchema) - // Verify that MetadataLogFileCatalog is being used and the correct partitioning schema has + // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has // been inferred val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] => baseRelation.asInstanceOf[HadoopFsRelation] } assert(hadoopdFsRelations.size === 1) - assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileCatalog]) + assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex]) assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id")) assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index b9e9da9a1ec5..47018b3a3c49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -879,7 +879,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val numFiles = 10000 // This is to avoid running a spark job to list of files in parallel - // by the ListingFileCatalog. + // by the InMemoryFileIndex. spark.sessionState.conf.setConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD, numFiles * 2) withTempDirs { case (root, tmp) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d1de863ce362..624ab747e442 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -200,7 +200,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val rootPaths: Seq[Path] = if (lazyPruningEnabled) { Seq(metastoreRelation.hiveQlTable.getDataLocation) } else { - // By convention (for example, see TableFileCatalog), the definition of a + // By convention (for example, see CatalogFileIndex), the definition of a // partitioned table's paths depends on whether that table has any actual partitions. // Partitioned tables without partitions use the location of the table's base path. 
// Partitioned tables with partitions use the locations of those partitions' data @@ -227,7 +227,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val logicalRelation = cached.getOrElse { val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong val fileCatalog = { - val catalog = new TableFileCatalog( + val catalog = new CatalogFileIndex( sparkSession, metastoreRelation.catalogTable, sizeInBytes) if (lazyPruningEnabled) { catalog diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index ecdf4f14b398..fc35304c80ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, TableFileCatalog} +import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -321,17 +321,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("DROP TABLE cachedTable") } - test("cache a table using TableFileCatalog") { + test("cache a table using CatalogFileIndex") { withTable("test") { sql("CREATE TABLE test(i int) PARTITIONED BY (p int) STORED AS parquet") val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test") - val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0) + val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0) val dataSchema = StructType(tableMeta.schema.filterNot { f => tableMeta.partitionColumnNames.contains(f.name) }) val relation = HadoopFsRelation( - location = tableFileCatalog, + location = catalogFileIndex, partitionSchema = tableMeta.partitionSchema, dataSchema = dataSchema, bucketSpec = None, @@ -343,7 +343,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) - val sameCatalog = new TableFileCatalog(spark, tableMeta, 0) + val sameCatalog = new CatalogFileIndex(spark, tableMeta, 0) val sameRelation = HadoopFsRelation( location = sameCatalog, partitionSchema = tableMeta.partitionSchema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 476383a5b33a..d8e31c4e39a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -256,7 +256,7 @@ class PartitionedTablePerfStatsSuite // of doing plan cache validation based on the entire partition set. 
HiveCatalogMetrics.reset() assert(spark.sql("select * from test where partCol1 = 999").count() == 0) - // 5 from table resolution, another 5 from ListingFileCatalog + // 5 from table resolution, another 5 from InMemoryFileIndex assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 10) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 5) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 59639aacf3a3..cdbc26cd5c57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions, TableFileCatalog} +import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -45,13 +45,13 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te |LOCATION '${dir.getAbsolutePath}'""".stripMargin) val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test") - val tableFileCatalog = new TableFileCatalog(spark, tableMeta, 0) + val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0) val dataSchema = StructType(tableMeta.schema.filterNot { f => tableMeta.partitionColumnNames.contains(f.name) }) val relation = HadoopFsRelation( - location = tableFileCatalog, + location = catalogFileIndex, partitionSchema = tableMeta.partitionSchema, dataSchema = dataSchema, bucketSpec = None, From 8ae2da0b2551011e2f6cf02907a1e20c138a4b2f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 30 Oct 2016 23:24:30 +0100 Subject: [PATCH 053/381] [SPARK-18106][SQL] ANALYZE TABLE should raise a ParseException for invalid option ## What changes were proposed in this pull request? Currently, `ANALYZE TABLE` command accepts `identifier` for option `NOSCAN`. This PR raises a ParseException for unknown option. **Before** ```scala scala> sql("create table test(a int)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("analyze table test compute statistics blah") res1: org.apache.spark.sql.DataFrame = [] ``` **After** ```scala scala> sql("create table test(a int)") res0: org.apache.spark.sql.DataFrame = [] scala> sql("analyze table test compute statistics blah") org.apache.spark.sql.catalyst.parser.ParseException: Expected `NOSCAN` instead of `blah`(line 1, pos 0) ``` ## How was this patch tested? Pass the Jenkins test with a new test case. Author: Dongjoon Hyun Closes #15640 from dongjoon-hyun/SPARK-18106. 
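As an illustration of the new behavior (a sketch, not part of the patch), callers can now catch the invalid option as a `ParseException` instead of having it silently ignored:

```scala
import org.apache.spark.sql.catalyst.parser.ParseException

try {
  spark.sql("ANALYZE TABLE test COMPUTE STATISTICS blah")
} catch {
  case e: ParseException =>
    // the message now reads: Expected `NOSCAN` instead of `blah`
    println(e.getMessage)
}
```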
--- .../spark/sql/execution/SparkSqlParser.scala | 10 +++++++--- .../sql/execution/SparkSqlParserSuite.scala | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 1cc166d5a7a9..fe183d0097d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -98,9 +98,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec == null && - ctx.identifier != null && - ctx.identifier.getText.toLowerCase == "noscan") { + if (ctx.partitionSpec != null) { + logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") + } + if (ctx.identifier != null) { + if (ctx.identifier.getText.toLowerCase != "noscan") { + throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) + } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) } else if (ctx.identifierSeq() == null) { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 679150e9ae4c..797fe9ffa8be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand, - ShowFunctionsCommand} +import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand, + DescribeTableCommand, ShowFunctionsCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -220,4 +220,18 @@ class SparkSqlParserSuite extends PlanTest { intercept("explain describe tables x", "Unsupported SQL statement") } + + test("SPARK-18106 analyze table") { + assertEqual("analyze table t compute statistics", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("analyze table t compute statistics noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + assertEqual("analyze table t partition (a) compute statistics noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + + intercept("analyze table t compute statistics xxxx", + "Expected `NOSCAN` instead of `xxxx`") + intercept("analyze table t partition (a) compute statistics xxxx", + "Expected `NOSCAN` instead of `xxxx`") + } } From 2881a2d1d1a650a91df2c6a01275eba14a43b42a Mon Sep 17 00:00:00 2001 From: Hossein Date: Sun, 30 Oct 2016 16:17:23 -0700 Subject: [PATCH 054/381] [SPARK-17919] Make timeout to RBackend configurable in SparkR ## What changes were proposed in this pull request? 
This patch makes RBackend connection timeout configurable by user. ## How was this patch tested? N/A Author: Hossein Closes #15471 from falaki/SPARK-17919. --- R/pkg/R/backend.R | 20 ++++++++-- R/pkg/R/client.R | 2 +- R/pkg/R/sparkR.R | 8 +++- R/pkg/inst/worker/daemon.R | 4 +- R/pkg/inst/worker/worker.R | 7 +++- .../org/apache/spark/api/r/RBackend.scala | 15 ++++++- .../apache/spark/api/r/RBackendHandler.scala | 39 +++++++++++++++++-- .../org/apache/spark/api/r/RRunner.scala | 3 ++ .../apache/spark/api/r/SparkRDefaults.scala | 30 ++++++++++++++ .../org/apache/spark/deploy/RRunner.scala | 7 +++- docs/configuration.md | 15 +++++++ 11 files changed, 134 insertions(+), 16 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 03e70bb2cb82..0a789e6c379d 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) { conn <- get(".sparkRCon", .sparkREnv) writeBin(requestMessage, conn) - # TODO: check the status code to output error information returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + + # Backend will send +1 as keep alive value to prevent various connection timeouts + # on very long running jobs. See spark.r.heartBeatInterval + while (returnStatus == 1) { + returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + } + + readObject(conn) +} + +# Helper function to check for returned errors and print appropriate error message to user +handleErrors <- function(returnStatus, conn) { if (length(returnStatus) == 0) { stop("No status is returned. Java SparkR backend might have failed.") } - if (returnStatus != 0) { + + # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors. 
+ if (returnStatus < 0) { stop(readString(conn)) } - readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 2d341d836c13..9d82814211bc 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout = 6000) { +connectBackend <- function(hostname, port, timeout) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc6d591bb2f4..6b4a2f2fdc85 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -154,6 +154,7 @@ sparkR.sparkContext <- function( packages <- processSparkPackages(sparkPackages) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) if (existingPort != "") { if (length(packages) != 0) { warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", @@ -187,6 +188,7 @@ sparkR.sparkContext <- function( backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) + connectionTimeout <- readInt(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || @@ -194,7 +196,9 @@ sparkR.sparkContext <- function( length(rLibPath) != 1) { stop("JVM failed to launch") } - assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".monitorConn", + socketConnection(port = monitorPort, timeout = connectionTimeout), + envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -204,7 +208,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort) + connectBackend("localhost", backendPort, timeout = connectionTimeout) }, error = function(err) { stop("Failed to connect JVM\n") diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index b92e6be995ca..3a318b71ea06 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,6 +18,7 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") @@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) +inputCon <- socketConnection( + port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) while (TRUE) { ready <- socketSelect(list(inputCon)) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index cfe41ded200c..03e745014786 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -90,6 +90,7 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]] 
suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") -outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") +inputCon <- socketConnection( + port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) +outputCon <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 41d0a85ee3ad..550746c552d0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -43,7 +44,10 @@ private[spark] class RBackend { def init(): Int = { val conf = new SparkConf() - bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + bossGroup = new NioEventLoopGroup( + conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) @@ -63,6 +67,7 @@ private[spark] class RBackend { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) + .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) .addLast("handler", handler) } }) @@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging { val boundPort = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() + // Connection timeout is set by socket client. 
To make it configurable we will pass the + // timeout value to client inside the temp file + val conf = new SparkConf() + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) // tell the R process via temporary file val path = args(0) @@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(boundPort) dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) + dos.writeInt(backendConnectionTimeout) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 1422ef888fd4..9f5afa29d6d2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -18,16 +18,19 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.timeout.ReadTimeoutException import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} /** * Handler for RBackend @@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend) writeString(dos, s"Error: unknown method $methodName") } } else { + // To avoid timeouts when reading results in SparkR driver, we will be regularly sending + // heartbeat responses. We use special code +1 to signal the client that backend is + // alive and it should continue blocking for result. + val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread") + val pingRunner = new Runnable { + override def run(): Unit = { + val pingBaos = new ByteArrayOutputStream() + val pingDaos = new DataOutputStream(pingBaos) + writeInt(pingDaos, +1) + ctx.write(pingBaos.toByteArray) + } + } + val conf = new SparkConf() + val heartBeatInterval = conf.getInt( + "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) + + execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS) handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + execService.shutdown() + execService.awaitTermination(1, TimeUnit.SECONDS) } val reply = bos.toByteArray @@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend) } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Close the connection when an exception is raised. - cause.printStackTrace() - ctx.close() + cause match { + case timeout: ReadTimeoutException => + // Do nothing. We don't want to timeout on read + logWarning("Ignoring read timeout in RBackendHandler") + case _ => + // Close the connection when an exception is raised. 
+ cause.printStackTrace() + ctx.close() + } } def handleMethodCall( diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 496fdf851f7d..7ef64723d959 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -333,6 +333,8 @@ private[r] object RRunner { var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") rCommand = sparkConf.get("spark.r.command", rCommand) + val rConnectionTimeout = sparkConf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir(0) + "/SparkR/worker/" + script @@ -344,6 +346,7 @@ private[r] object RRunner { pb.environment().put("R_TESTS", "") pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala new file mode 100644 index 000000000000..af67cbbce4e5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +private[spark] object SparkRDefaults { + + // Default value for spark.r.backendConnectionTimeout config + val DEFAULT_CONNECTION_TIMEOUT: Int = 6000 + + // Default value for spark.r.heartBeatInterval config + val DEFAULT_HEARTBEAT_INTERVAL: Int = 100 + + // Default value for spark.r.numRBackendThreads config + val DEFAULT_NUM_RBACKEND_THREADS = 2 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index d0466830b217..6eb53a825220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkUserAppException} -import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults} import org.apache.spark.util.RedirectThread /** @@ -51,6 +51,10 @@ object RRunner { cmd } + // Connection timeout set by R process on its connection to RBackend in seconds. 
+ val backendConnectionTimeout = sys.props.getOrElse( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString) + // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode val rF = new File(rFile) @@ -81,6 +85,7 @@ object RRunner { val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) // Put the R package directories into an env variable of comma-separated paths env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) diff --git a/docs/configuration.md b/docs/configuration.md index 6600cb6c0ac0..780fc94908d3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1890,6 +1890,21 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.r.shell.command is used for sparkR shell while spark.r.driver.command is used for running R script. + + spark.r.backendConnectionTimeout + 6000 + + Connection timeout set by R process on its connection to RBackend in seconds. + + + + spark.r.heartBeatInterval + 100 + + Interval for heartbeats sents from SparkR backend to R process to prevent connection timeout. + + + #### Deploy From b6879b8b3518c71c23262554fcb0fdad60287011 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 30 Oct 2016 16:19:19 -0700 Subject: [PATCH 055/381] [SPARK-16137][SPARKR] randomForest for R ## What changes were proposed in this pull request? Random Forest Regression and Classification for R Clean-up/reordering generics.R ## How was this patch tested? manual tests, unit tests Author: Felix Cheung Closes #15607 from felixcheung/rrandomforest. 
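The new wrappers are a thin layer over the existing ML Pipeline API: an `RFormula` stage builds the feature vector and label from the R formula, and a `RandomForestRegressor` or `RandomForestClassifier` stage is fitted on its output. The following minimal Scala sketch shows that composition only; the toy data, column names, and the `RFormulaForestSketch` object are invented for the example, and just the pipeline shape mirrors the wrapper code in this patch.

```scala
// Sketch of the RFormula + RandomForestRegressor pipeline that the SparkR wrapper
// assembles internally. Toy data and names are made up for illustration.
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.sql.SparkSession

object RFormulaForestSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("rf-sketch").getOrCreate()
    import spark.implicits._

    // Toy training data: two numeric predictors and a numeric label.
    val df = Seq((1.0, 2.0, 3.5), (2.0, 1.0, 2.5), (3.0, 4.0, 6.0), (4.0, 3.0, 5.0))
      .toDF("x1", "x2", "y")

    // RFormula translates the R-style formula into features/label columns,
    // as the wrapper does before handing the data to the estimator.
    val formula = new RFormula().setFormula("y ~ x1 + x2")

    val rf = new RandomForestRegressor()
      .setNumTrees(5)
      .setMaxDepth(3)
      .setFeaturesCol(formula.getFeaturesCol)

    // Fit the two-stage pipeline and score the training data.
    val model = new Pipeline().setStages(Array[PipelineStage](formula, rf)).fit(df)
    model.transform(df).select("y", "prediction").show()

    spark.stop()
  }
}
```

On the R side, the equivalent call after this patch is the one exercised by the new tests, e.g. `spark.randomForest(df, Employed ~ ., "regression", numTrees = 20)`.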
--- R/pkg/NAMESPACE | 9 +- R/pkg/R/generics.R | 66 ++--- R/pkg/R/mllib.R | 252 +++++++++++++++++- R/pkg/inst/tests/testthat/test_mllib.R | 68 +++++ .../org/apache/spark/ml/r/RWrappers.scala | 4 + .../r/RandomForestClassificationWrapper.scala | 147 ++++++++++ .../ml/r/RandomForestRegressionWrapper.scala | 144 ++++++++++ 7 files changed, 656 insertions(+), 34 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7a89c01fee73..9cd6269f9a8f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -44,7 +44,8 @@ exportMethods("glm", "spark.gaussianMixture", "spark.als", "spark.kstest", - "spark.logit") + "spark.logit", + "spark.randomForest") # Job group lifecycle management methods export("setJobGroup", @@ -350,7 +351,9 @@ export("as.DataFrame", "uncacheTable", "print.summary.GeneralizedLinearRegressionModel", "read.ml", - "print.summary.KSTest") + "print.summary.KSTest", + "print.summary.RandomForestRegressionModel", + "print.summary.RandomForestClassificationModel") export("structField", "structField.jobj", @@ -375,6 +378,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.RandomForestRegressionModel) +S3method(print, summary.RandomForestClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 107e1c638be7..0271b26a10a9 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) -#' @rdname spark.glm +###################### Spark.ML Methods ########################## + +#' @rdname fitted #' @export -setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +setGeneric("fitted") #' @param x,y For \code{glm}: logical values indicating whether the response vector #' and model matrix used in the fitting process should be returned as @@ -1332,13 +1334,38 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @export setGeneric("rbind", signature = "...") +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) + +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) + +#' @rdname spark.glm +#' @export +setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) + +#' @rdname spark.isoreg +#' @export +setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) + #' @rdname spark.kmeans #' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) -#' @rdname fitted +#' @rdname spark.kstest #' @export -setGeneric("fitted") +setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.logit +#' @export +setGeneric("spark.logit", function(data, formula, ...) 
{ standardGeneric("spark.logit") }) #' @rdname spark.mlp #' @export @@ -1348,13 +1375,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) -#' @rdname spark.survreg +#' @rdname spark.randomForest #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.randomForest", + function(data, formula, ...) { standardGeneric("spark.randomForest") }) -#' @rdname spark.lda +#' @rdname spark.survreg #' @export -setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) #' @rdname spark.lda #' @export @@ -1364,20 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) -#' @rdname spark.isoreg -#' @export -setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) - -#' @rdname spark.gaussianMixture -#' @export -setGeneric("spark.gaussianMixture", - function(data, formula, ...) { - standardGeneric("spark.gaussianMixture") - }) - -#' @rdname spark.logit -#' @export -setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1385,11 +1399,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) - -#' @rdname spark.als -#' @export -setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) - -#' @rdname spark.kstest -#' @export -setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 629f284b79f3..7a220b8d53a2 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -102,6 +102,20 @@ setClass("KSTest", representation(jobj = "jobj")) #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. 
For more information, see the specific @@ -112,7 +126,7 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.survreg} +#' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} NULL @@ -125,7 +139,8 @@ NULL #' @export #' @seealso \link{spark.glm}, \link{glm}, #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, +#' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL write_internal <- function(object, path, overwrite = FALSE) { @@ -1122,6 +1137,10 @@ read.ml <- function(path) { new("ALSModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1617,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) { cat(summaryStr, "\n") invisible(x) } + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini". (default = gini) +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. 
+#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. (default = 1.0) +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. (default = "probability") +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. +#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), as.character(probabilityCol), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.randomForest +#' @aliases predict,RandomForestRegressionModel-method +#' @export +#' @note predict(randomForestRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.randomForest +#' @aliases predict,RandomForestClassificationModel-method +#' @export +#' @note predict(randomForestClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Random Forest Regression or Classification model to the input path. + +#' @param object A fitted Random Forest regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,RandomForestRegressionModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,RandomForestClassificationModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of an RandomForestRegressionModel model +summary.randomForest <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +#' @return \code{summary} returns the model's features as lists, depth and number of nodes +#' or number of classes. 
+#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Get the summary of an RandomForestClassificationModel model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.randomForest(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of Random Forest Regression Model +print.summary.randomForest <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) { + print.summary.randomForest(x) +} + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) 
{ + print.summary.randomForest(x) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 6d1fccc7c058..db98d0e45547 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -871,4 +871,72 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) +test_that("spark.randomForest Regression", { + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, + 63.736, 64.296, 64.868, 64.300, + 66.709, 67.697, 67.966, 67.252, + 68.866, 69.593, 69.195, 69.658), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) +}) + +test_that("spark.randomForest Classification", { + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) +}) + sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 1df3662a5822..0e09e18027ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -56,6 +56,10 @@ private[r] object RWrappers extends MLReader[Object] { ALSWrapper.load(path) case "org.apache.spark.ml.r.LogisticRegressionWrapper" => LogisticRegressionWrapper.load(path) + case "org.apache.spark.ml.r.RandomForestRegressorWrapper" => + RandomForestRegressorWrapper.load(path) + case 
"org.apache.spark.ml.r.RandomForestClassifierWrapper" => + RandomForestClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala new file mode 100644 index 000000000000..b0088ddaf3b1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestClassificationModel = + pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this) +} + +private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + probabilityCol: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get 
+ val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfc = new RandomForestClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setProbabilityCol(probabilityCol) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc)) + .fit(data) + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestClassifierWrapper] = + new RandomForestClassifierWrapperReader + + override def load(path: String): RandomForestClassifierWrapper = super.load(path) + + class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] { + + override def load(path: String): RandomForestClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala new file mode 100644 index 000000000000..c8874407fa75 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class RandomForestRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val DTModel: RandomForestRegressionModel = + pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] + + lazy val numFeatures: Int = DTModel.numFeatures + lazy val featureImportances: Vector = DTModel.featureImportances + lazy val numTrees: Int = DTModel.getNumTrees + lazy val treeWeights: Array[Double] = DTModel.treeWeights + + def summary: String = DTModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + } + + override def write: MLWriter = new + RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this) +} + +private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + numTrees: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + featureSubsetStrategy: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): RandomForestRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new RandomForestRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setNumTrees(numTrees) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setFeatureSubsetStrategy(featureSubsetStrategy) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[RandomForestRegressorWrapper] = new RandomForestRegressorWrapperReader + + override def load(path: String): RandomForestRegressorWrapper = super.load(path) + + class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: 
String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] { + + override def load(path: String): RandomForestRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new RandomForestRegressorWrapper(pipeline, formula, features) + } + } +} From 7c3786929205b962b430cf7fc292602c2993c193 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 30 Oct 2016 16:21:37 -0700 Subject: [PATCH 056/381] [SPARK-18110][PYTHON][ML] add missing parameter in Python for RandomForest regression and classification ## What changes were proposed in this pull request? Add subsmaplingRate to randomForestClassifier Add varianceCol to randomForestRegressor In Python ## How was this patch tested? manual tests Author: Felix Cheung Closes #15638 from felixcheung/pyrandomforest. --- python/pyspark/ml/classification.py | 11 ++++++----- python/pyspark/ml/regression.py | 12 ++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3f763a10d406..d9ff356b9403 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -758,20 +758,21 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", - numTrees=20, featureSubsetStrategy="auto", seed=None): + numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ - numTrees=20, featureSubsetStrategy="auto", seed=None) + numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0) """ super(RandomForestClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="gini", numTrees=20, featureSubsetStrategy="auto") + impurity="gini", numTrees=20, featureSubsetStrategy="auto", + subsamplingRate=1.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -781,13 +782,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="gini", numTrees=20, featureSubsetStrategy="auto"): + impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0): 
""" setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ - impurity="gini", numTrees=20, featureSubsetStrategy="auto") + impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) Sets params for linear classification. """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 55d38033ef72..9233d2e7e1a7 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -594,7 +594,7 @@ class RandomForestParams(TreeEnsembleParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].", + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", typeConverter=TypeConverters.toString) def __init__(self): @@ -828,7 +828,7 @@ def featureImportances(self): @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, RandomForestParams, TreeRegressorParams, HasCheckpointInterval, - JavaMLWritable, JavaMLReadable): + JavaMLWritable, JavaMLReadable, HasVarianceCol): """ `Random Forest `_ learning algorithm for regression. @@ -876,13 +876,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto"): + featureSubsetStrategy="auto", varianceCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto") + featureSubsetStrategy="auto", varianceCol=None) """ super(RandomForestRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -900,13 +900,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, - featureSubsetStrategy="auto"): + featureSubsetStrategy="auto", varianceCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ - featureSubsetStrategy="auto") + featureSubsetStrategy="auto", varianceCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From d2923f173265b66a4ec71c3c86ff71a58d5aeb3d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 31 Oct 2016 00:11:33 -0700 Subject: [PATCH 057/381] [SPARK-18143][SQL] Ignore Structured Streaming event logs to avoid breaking history server ## What changes were proposed in this pull request? 
Because of the refactoring work in Structured Streaming, the event logs generated by Strucutred Streaming in Spark 2.0.0 and 2.0.1 cannot be parsed. This PR just ignores these logs in ReplayListenerBus because no places use them. ## How was this patch tested? - Generated events logs using Spark 2.0.0 and 2.0.1, and saved them as `structured-streaming-query-event-logs-2.0.0.txt` and `structured-streaming-query-event-logs-2.0.1.txt` - The new added test makes sure ReplayListenerBus will skip these bad jsons. Author: Shixiong Zhu Closes #15663 from zsxwing/fix-event-log. --- .../spark/scheduler/ReplayListenerBus.scala | 13 ++++++ .../query-event-logs-version-2.0.0.txt | 4 ++ .../query-event-logs-version-2.0.1.txt | 4 ++ .../StreamingQueryListenerSuite.scala | 42 +++++++++++++++++++ 4 files changed, 63 insertions(+) create mode 100644 sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt create mode 100644 sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 3eff8d952bfd..2424586431aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -72,6 +72,10 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) } catch { + case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) => + // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. + // It's safe since no place uses them. + logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it @@ -102,4 +106,13 @@ private[spark] object ReplayListenerBus { // utility filter that selects all event logs during replay val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true } + + /** + * Classes that were removed. Structured Streaming doesn't use them any more. However, parsing + * old json may fail and we can just ignore these failures. 
+ */ + val KNOWN_REMOVED_CLASSES = Set( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress", + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated" + ) } diff --git a/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt new file mode 100644 index 000000000000..aa7e9a8c20c4 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.0.txt @@ -0,0 +1,4 @@ +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@2b85b3a5","offsetDesc":"[#0]"}}} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@2b85b3a5","offsetDesc":"[#0]"}},"exception":null,"stackTrace":[]} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@514502dc","offsetDesc":"[-]"}},"exception":"Query hello terminated with exception: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:85)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver 
stacktrace:","stackTrace":[{"methodName":"org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches","fileName":"StreamExecution.scala","lineNumber":208,"className":"org.apache.spark.sql.execution.streaming.StreamExecution","nativeMethod":false},{"methodName":"run","fileName":"StreamExecution.scala","lineNumber":120,"className":"org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1","nativeMethod":false}]} +{"Event":"SparkListenerApplicationEnd","Timestamp":1477593059313} diff --git a/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt new file mode 100644 index 000000000000..646cf107183b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/query-event-logs-version-2.0.1.txt @@ -0,0 +1,4 @@ +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@10e5ec94","offsetDesc":"[#0]"}}} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@10e5ec94","offsetDesc":"[#0]"}},"exception":null} +{"Event":"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated","queryInfo":{"name":"hello","id":0,"sourceStatuses":[{"description":"FileStreamSource[file:/Users/zsx/stream]","offsetDesc":"#0"}],"sinkStatus":{"description":"org.apache.spark.sql.execution.streaming.MemorySink@70c61dc8","offsetDesc":"[-]"}},"exception":"org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:86)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat 
java.lang.Thread.run(Thread.java:745)\n\nDriver stacktrace:\n\tat org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1454)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1442)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1441)\n\tat scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)\n\tat scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)\n\tat org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1441)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:811)\n\tat org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:811)\n\tat scala.Option.foreach(Option.scala:257)\n\tat org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:811)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1667)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1622)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1611)\n\tat org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)\n\tat org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:632)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1890)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1903)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1916)\n\tat org.apache.spark.SparkContext.runJob(SparkContext.scala:1930)\n\tat org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:912)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)\n\tat org.apache.spark.rdd.RDD.withScope(RDD.scala:358)\n\tat org.apache.spark.rdd.RDD.collect(RDD.scala:911)\n\tat org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:290)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2193)\n\tat org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)\n\tat org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2546)\n\tat org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2192)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset.withCallback(Dataset.scala:2559)\n\tat org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2197)\n\tat org.apache.spark.sql.Dataset.collect(Dataset.scala:2173)\n\tat org.apache.spark.sql.execution.streaming.MemorySink.addBatch(memory.scala:154)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatch(StreamExecution.scala:366)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution$$anonfun$org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches$1.apply$mcZ$sp(StreamExecution.scala:197)\n\tat org.apache.spark.sql.execution.streaming.ProcessingTimeExecutor.execute(TriggerExecutor.scala:43)\n\tat 
org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runBatches(StreamExecution.scala:187)\n\tat org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:124)\nCaused by: java.lang.ArithmeticException: / by zero\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:25)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)\n\tat org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:283)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:86)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1477701734609} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index ff843865a017..cebb32a0a56c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.streaming +import scala.collection.mutable + import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ import org.apache.spark.SparkException +import org.apache.spark.scheduler._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ @@ -206,6 +209,45 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(queryQueryTerminated.exception === newQueryTerminated.exception) } + test("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") { + // query-event-logs-version-2.0.0.txt has all types of events generated by + // Structured Streaming in Spark 2.0.0. + // SparkListenerApplicationEnd is the only valid event and it's the last event. We use it + // to verify that we can skip broken jsons generated by Structured Streaming. + testReplayListenerBusWithBorkenEventJsons("query-event-logs-version-2.0.0.txt") + } + + test("ReplayListenerBus should ignore broken event jsons generated in 2.0.1") { + // query-event-logs-version-2.0.1.txt has all types of events generated by + // Structured Streaming in Spark 2.0.1. 
+ // SparkListenerApplicationEnd is the only valid event and it's the last event. We use it + // to verify that we can skip broken jsons generated by Structured Streaming. + testReplayListenerBusWithBorkenEventJsons("query-event-logs-version-2.0.1.txt") + } + + private def testReplayListenerBusWithBorkenEventJsons(fileName: String): Unit = { + val input = getClass.getResourceAsStream(s"/structured-streaming/$fileName") + val events = mutable.ArrayBuffer[SparkListenerEvent]() + try { + val replayer = new ReplayListenerBus() { + // Redirect all parsed events to `events` + override def doPostEvent( + listener: SparkListenerInterface, + event: SparkListenerEvent): Unit = { + events += event + } + } + // Add a dummy listener so that "doPostEvent" will be called. + replayer.addListener(new SparkListener {}) + replayer.replay(input, fileName) + // SparkListenerApplicationEnd is the only valid event + assert(events.size === 1) + assert(events(0).isInstanceOf[SparkListenerApplicationEnd]) + } finally { + input.close() + } + } + private def assertStreamingQueryInfoEquals( expected: StreamingQueryStatus, actual: StreamingQueryStatus): Unit = { From 26b07f1908eeffd934b1e86fb4de02f69945e004 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 31 Oct 2016 10:10:22 +0000 Subject: [PATCH 058/381] [BUILD] Close stale Pull Requests. Closes #11610 Closes #15411 Closes #15501 Closes #12613 Closes #12518 Closes #12026 Closes #15524 Closes #12693 Closes #12358 Closes #15588 Closes #15635 Closes #15678 Closes #14699 Closes #9008 Author: Sean Owen Closes #15685 from srowen/CloseStalePRs. From 8bfc3b7aac577e36aadc4fe6dee0665d0b2ae919 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 31 Oct 2016 13:39:59 -0700 Subject: [PATCH 059/381] [SPARK-17972][SQL] Add Dataset.checkpoint() to truncate large query plans ## What changes were proposed in this pull request? ### Problem Iterative ML code may easily create query plans that grow exponentially. We found that query planning time also increases exponentially even when all the sub-plan trees are cached. The following snippet illustrates the problem: ``` scala (0 until 6).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) => println(s"== Iteration $iteration ==") val time0 = System.currentTimeMillis() val joined = plan.join(plan, "value").join(plan, "value").join(plan, "value").join(plan, "value") joined.cache() println(s"Query planning takes ${System.currentTimeMillis() - time0} ms") joined.as[Int] } // == Iteration 0 == // Query planning takes 9 ms // == Iteration 1 == // Query planning takes 26 ms // == Iteration 2 == // Query planning takes 53 ms // == Iteration 3 == // Query planning takes 163 ms // == Iteration 4 == // Query planning takes 700 ms // == Iteration 5 == // Query planning takes 3418 ms ``` This is because when building a new Dataset, the new plan is always built upon `QueryExecution.analyzed`, which doesn't leverage existing cached plans. On the other hand, usually, doing caching every a few iterations may not be the right direction for this problem since caching is too memory consuming (imaging computing connected components over a graph with 50 billion nodes). What we really need here is to truncate both the query plan (to minimize query planning time) and the lineage of the underlying RDD (to avoid stack overflow). ### Changes introduced in this PR This PR tries to fix this issue by introducing a `checkpoint()` method into `Dataset[T]`, which does exactly the things described above. 
One key point is that the checkpointed Dataset should preserve the partitioning and ordering information of the original Dataset, so that we can avoid unnecessary shuffling (similar to reading from a pre-bucketed table). This is done by adding `outputPartitioning` and `outputOrdering` to `LogicalRDD` and `RDDScanExec`. ### Micro benchmark The following snippet, which is essentially the same as the one above but invokes `checkpoint()` instead of `cache()`, shows the micro benchmark result of this PR:

``` scala
spark.sparkContext.setCheckpointDir("/tmp/cp")

(0 until 100).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")
  val time0 = System.currentTimeMillis()
  val cp = plan.checkpoint()
  cp.count()
  System.out.println(s"Checkpointing takes ${System.currentTimeMillis() - time0} ms")

  val time1 = System.currentTimeMillis()
  val joined = cp.join(cp, "value").join(cp, "value").join(cp, "value").join(cp, "value")
  val result = joined.as[Int]
  println(s"Query planning takes ${System.currentTimeMillis() - time1} ms")

  result
}

// == Iteration 0 ==
// Checkpointing takes 591 ms
// Query planning takes 13 ms
// == Iteration 1 ==
// Checkpointing takes 1605 ms
// Query planning takes 16 ms
// == Iteration 2 ==
// Checkpointing takes 782 ms
// Query planning takes 8 ms
// == Iteration 3 ==
// Checkpointing takes 729 ms
// Query planning takes 10 ms
// == Iteration 4 ==
// Checkpointing takes 734 ms
// Query planning takes 9 ms
// == Iteration 5 ==
// ...
// == Iteration 50 ==
// Checkpointing takes 571 ms
// Query planning takes 7 ms
// == Iteration 51 ==
// Checkpointing takes 548 ms
// Query planning takes 7 ms
// == Iteration 52 ==
// Checkpointing takes 596 ms
// Query planning takes 8 ms
// == Iteration 53 ==
// Checkpointing takes 568 ms
// Query planning takes 7 ms
// ...
```

You can see that although checkpointing is a more heavyweight operation, the combined time for checkpointing and query planning stays roughly constant across iterations. ### Open question mengxr mentioned that it would be more convenient if we could make `Dataset.checkpoint()` eager, i.e., always perform an `RDD.count()` after calling `RDD.checkpoint()`. It is not clear whether this is a universal requirement. Maybe we can add an `eager: Boolean` argument for `Dataset.checkpoint()` to support that. ## How was this patch tested? Unit test added in `DatasetSuite`. Author: Cheng Lian Author: Yin Huai Closes #15651 from liancheng/ds-checkpoint.
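As a quick orientation before the diff, here is a minimal usage sketch of the API this patch adds. It is not part of the patch itself; it assumes a `SparkSession` named `spark` with `spark.implicits._` in scope, and the checkpoint directory is illustrative.

``` scala
import org.apache.spark.sql.functions._
import spark.implicits._

spark.sparkContext.setCheckpointDir("/tmp/cp")

val ds = spark.range(10).repartition('id % 2)

// Eager checkpoint (the default): materializes the underlying RDD right away and
// returns a Dataset backed by a single LogicalRDD node, truncating both the query
// plan and the RDD lineage.
val cp = ds.checkpoint()

// Lazy variant: only marks the RDD for checkpointing; it is materialized on first use.
val lazyCp = ds.checkpoint(eager = false)

// Because LogicalRDD now carries outputPartitioning and outputOrdering, this
// aggregation should not need an extra shuffle above the checkpointed scan.
cp.groupBy('id % 2).agg(count('id)).show()
```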
--- .../scala/org/apache/spark/sql/Dataset.scala | 57 +++++++++++++++- .../spark/sql/execution/ExistingRDD.scala | 37 ++++++++-- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../org/apache/spark/sql/DatasetSuite.scala | 68 +++++++++++++++++++ 4 files changed, 157 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 286d8549bfe2..6e0a2471e0fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery} +import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -482,6 +483,58 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming + /** + * Returns a checkpointed version of this Dataset. + * + * @group basic + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + def checkpoint(): Dataset[T] = checkpoint(eager = true) + + /** + * Returns a checkpointed version of this Dataset. + * + * @param eager When true, materializes the underlying checkpointed RDD eagerly. + * + * @group basic + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + def checkpoint(eager: Boolean): Dataset[T] = { + val internalRdd = queryExecution.toRdd.map(_.copy()) + internalRdd.checkpoint() + + if (eager) { + internalRdd.count() + } + + val physicalPlan = queryExecution.executedPlan + + // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the + // size of `PartitioningCollection` may grow exponentially for queries involving deep inner + // joins. + def firstLeafPartitioning(partitioning: Partitioning): Partitioning = { + partitioning match { + case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head) + case p => p + } + } + + val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning) + + Dataset.ofRows( + sparkSession, + LogicalRDD( + logicalPlan.output, + internalRdd, + outputPartitioning, + physicalPlan.outputOrdering + )(sparkSession)).as[T] + } + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. 
For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d3a22228623e..455fb5bfbb6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T]( /** Logical plan node for scanning data from an RDD of InternalRow. */ case class LogicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow])(session: SparkSession) + rdd: RDD[InternalRow], + outputPartitioning: Partitioning = UnknownPartitioning(0), + outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession) extends LeafNode with MultiInstanceRelation { override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil - override def newInstance(): LogicalRDD.this.type = - LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type] + override def newInstance(): LogicalRDD.this.type = { + val rewrite = output.zip(output.map(_.newInstance())).toMap + + val rewrittenPartitioning = outputPartitioning match { + case p: Expression => + p.transform { + case e: Attribute => rewrite.getOrElse(e, e) + }.asInstanceOf[Partitioning] + + case p => p + } + + val rewrittenOrdering = outputOrdering.map(_.transform { + case e: Attribute => rewrite.getOrElse(e, e) + }.asInstanceOf[SortOrder]) + + LogicalRDD( + output.map(rewrite), + rdd, + rewrittenPartitioning, + rewrittenOrdering + )(session).asInstanceOf[this.type] + } override def sameResult(plan: LogicalPlan): Boolean = { plan.canonicalized match { - case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id case _ => false } } @@ -158,7 +181,9 @@ case class LogicalRDD( case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], - override val nodeName: String) extends LeafExecNode { + override val nodeName: String, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7cfae5ce283b..5412aca95dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} -import 
org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil - case r : logical.Range => + case r: logical.Range => execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil - case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil + case r: LogicalRDD => + RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc367acae2ba..55f04878052a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.withColumn("b", expr("0")).as[ClassData] .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) } + + Seq(true, false).foreach { eager => + def testCheckpointing(testName: String)(f: => Unit): Unit = { + test(s"Dataset.checkpoint() - $testName (eager = $eager)") { + withTempDir { dir => + val originalCheckpointDir = spark.sparkContext.checkpointDir + + try { + spark.sparkContext.setCheckpointDir(dir.getCanonicalPath) + f + } finally { + // Since the original checkpointDir can be None, we need + // to set the variable directly. 
+ spark.sparkContext.checkpointDir = originalCheckpointDir + } + } + } + } + + testCheckpointing("basic") { + val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) + val cp = ds.checkpoint(eager) + + val logicalRDD = cp.logicalPlan match { + case plan: LogicalRDD => plan + case _ => + val treeString = cp.logicalPlan.treeString(verbose = true) + fail(s"Expecting a LogicalRDD, but got\n$treeString") + } + + val dsPhysicalPlan = ds.queryExecution.executedPlan + val cpPhysicalPlan = cp.queryExecution.executedPlan + + assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning } + assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering } + + assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning } + assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering } + + // For a lazy checkpoint() call, the first check also materializes the checkpoint. + checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + + // Reads back from checkpointed data and check again. + checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + } + + testCheckpointing("should preserve partitioning information") { + val ds = spark.range(10).repartition('id % 2) + val cp = ds.checkpoint(eager) + + val agg = cp.groupBy('id % 2).agg(count('id)) + + agg.queryExecution.executedPlan.collectFirst { + case ShuffleExchange(_, _: RDDScanExec, _) => + case BroadcastExchangeExec(_, _: RDDScanExec) => + }.foreach { _ => + fail( + "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " + + "preserves partitioning information:\n\n" + agg.queryExecution + ) + } + + checkAnswer(agg, ds.groupBy('id % 2).agg(count('id))) + } + } } case class Generic[T](id: T, value: Double) From de3f87fa712c305fdd463fc36acffc5418c95c4d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 31 Oct 2016 16:05:17 -0700 Subject: [PATCH 060/381] [SPARK-18030][TESTS] Fix flaky FileStreamSourceSuite by not deleting the files ## What changes were proposed in this pull request? The test `when schema inference is turned on, should read partition data` should not delete files because the source maybe is listing files. This PR just removes the delete actions since they are not necessary. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15699 from zsxwing/SPARK-18030. 
--- .../spark/sql/streaming/FileStreamSourceSuite.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 47018b3a3c49..fab7642994ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -102,12 +102,6 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext with Private } } - case class DeleteFile(file: File) extends ExternalAction { - def runAction(): Unit = { - Utils.deleteRecursively(file) - } - } - /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ def createFileStream( format: String, @@ -697,10 +691,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest { AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")), - // Delete the two partition dirs - DeleteFile(partitionFooSubDir), - DeleteFile(partitionBarSubDir), - AddTextFileData("{'value': 'keep6'}", partitionBarSubDir, tmp), CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar"), ("keep6", "bar")) From 6633b97b579c7f003d60b6bfa2e2a248340d3dc6 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 31 Oct 2016 16:26:52 -0700 Subject: [PATCH 061/381] [SPARK-18167][SQL] Also log all partitions when the SQLQuerySuite test flakes ## What changes were proposed in this pull request? One possibility for this test flaking is that we have corrupted the partition schema somehow in the tests, which causes the cast to decimal to fail in the call. This should at least show us the actual partition values. ## How was this patch tested? Run it locally, it prints out something like `ArrayBuffer(test(partcol=0), test(partcol=1), test(partcol=2), test(partcol=3), test(partcol=4))`. Author: Eric Liang Closes #15701 from ericl/print-more-info. --- .../main/scala/org/apache/spark/sql/hive/client/HiveShim.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 4bbbd66132b7..85edaf63db88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -594,9 +594,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // SPARK-18167 retry to investigate the flaky test. This should be reverted before // the release is cut. val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter)) - val full = Try(getAllPartitionsMethod.invoke(hive, table)) logError("getPartitionsByFilter failed, retry success = " + retry.isSuccess) - logError("getPartitionsByFilter failed, full fetch success = " + full.isSuccess) + logError("all partitions: " + getAllPartitions(hive, table)) throw e } } From efc254a82bc3331d78023f00d29d4c4318dfb734 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 31 Oct 2016 19:46:55 -0700 Subject: [PATCH 062/381] [SPARK-18087][SQL] Optimize insert to not require REPAIR TABLE ## What changes were proposed in this pull request? When inserting into datasource tables with partitions managed by the hive metastore, we need to notify the metastore of newly added partitions. 
Previously this was implemented via `msck repair table`, but this is more expensive than needed. This optimizes the insertion path to add only the updated partitions. ## How was this patch tested? Existing tests (I verified manually that tests fail if the repair operation is omitted). Author: Eric Liang Closes #15633 from ericl/spark-18087. --- .../execution/datasources/DataSource.scala | 2 +- .../datasources/DataSourceStrategy.scala | 27 ++++++++++------- .../InsertIntoHadoopFsRelationCommand.scala | 3 +- .../datasources/PartitioningUtils.scala | 12 ++++++++ .../execution/datasources/WriteOutput.scala | 29 +++++++++++++------ 5 files changed, 52 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 996109865fdc..d980e6a15aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -528,7 +528,7 @@ case class DataSource( columns, bucketSpec, format, - () => Unit, // No existing table needs to be refreshed. + _ => Unit, // No existing table needs to be refreshed. options, data.logicalPlan, mode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f0bcf94eadc9..34b77cab65de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } + def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { + if (l.catalogTable.isDefined && + l.catalogTable.get.partitionColumnNames.nonEmpty && + l.catalogTable.get.partitionProviderIsHive) { + val metastoreUpdater = AlterTableAddPartitionCommand( + l.catalogTable.get.identifier, + updatedPartitions.map(p => (p, None)), + ifNotExists = true) + metastoreUpdater.run(t.sparkSession) + } + t.location.refresh() + } + 
val insertCmd = InsertIntoHadoopFsRelationCommand( outputPath, query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, t.fileFormat, - () => t.location.refresh(), + refreshPartitionsCallback, t.options, query, mode) - if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.partitionProviderIsHive) { - // TODO(ekl) we should be more efficient here and only recover the newly added partitions - val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier) - Union(insertCmd, recoverPartitionCmd) - } else { - insertCmd - } + insertCmd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 22dbe7149531..a1221d0ae6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand @@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand( partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: () => Unit, + refreshFunction: (Seq[TablePartitionSpec]) => Unit, options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f66e8b4e2b55..b51b41869bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -244,6 +245,17 @@ object PartitioningUtils { } } + /** + * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec + * for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`. + */ + def parsePathFragment(pathFragment: String): TablePartitionSpec = { + pathFragment.split("/").map { kv => + val pair = kv.split("=", 2) + (unescapePathName(pair(0)), unescapePathName(pair(1))) + }.toMap + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. 
e.g., if the partition spec has a column named `monTh`, and there is a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala index bd56e511d0cc..0eb86fdd6caa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} +import scala.collection.mutable + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow @@ -85,7 +88,7 @@ object WriteOutput extends Logging { hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], - refreshFunction: () => Unit, + refreshFunction: (Seq[TablePartitionSpec]) => Unit, options: Map[String, String], isAppend: Boolean): Unit = { @@ -120,7 +123,7 @@ object WriteOutput extends Logging { val committer = setupDriverCommitter(job, outputPath.toString, isAppend) try { - sparkSession.sparkContext.runJob(queryExecution.toRdd, + val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -128,11 +131,11 @@ object WriteOutput extends Logging { sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.attemptNumber(), iterator = iter) - }) + }).flatten.distinct committer.commitJob(job) logInfo(s"Job ${job.getJobID} committed.") - refreshFunction() + refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment)) } catch { case cause: Throwable => logError(s"Aborting job ${job.getJobID}.", cause) committer.abortJob(job, JobStatus.State.FAILED) @@ -147,7 +150,7 @@ object WriteOutput extends Logging { sparkStageId: Int, sparkPartitionId: Int, sparkAttemptNumber: Int, - iterator: Iterator[InternalRow]): Unit = { + iterator: Iterator[InternalRow]): Set[String] = { val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -187,11 +190,12 @@ object WriteOutput extends Logging { try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out - writeTask.execute(iterator) + val outputPaths = writeTask.execute(iterator) writeTask.releaseResources() // Commit the task SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId) + outputPaths })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -213,7 +217,7 @@ object WriteOutput extends Logging { * automatically trigger task aborts. 
*/ private trait ExecuteWriteTask { - def execute(iterator: Iterator[InternalRow]): Unit + def execute(iterator: Iterator[InternalRow]): Set[String] def releaseResources(): Unit final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = { @@ -240,11 +244,12 @@ object WriteOutput extends Logging { outputWriter } - override def execute(iter: Iterator[InternalRow]): Unit = { + override def execute(iter: Iterator[InternalRow]): Set[String] = { while (iter.hasNext) { val internalRow = iter.next() outputWriter.writeInternal(internalRow) } + Set.empty } override def releaseResources(): Unit = { @@ -327,7 +332,7 @@ object WriteOutput extends Logging { newWriter } - override def execute(iter: Iterator[InternalRow]): Unit = { + override def execute(iter: Iterator[InternalRow]): Set[String] = { // We should first sort by partition columns, then bucket id, and finally sorting columns. val sortingExpressions: Seq[Expression] = description.partitionColumns ++ bucketIdExpression ++ sortColumns @@ -375,6 +380,7 @@ object WriteOutput extends Logging { // If anything below fails, we should abort the task. var currentKey: UnsafeRow = null + val updatedPartitions = mutable.Set[String]() while (sortedIterator.next()) { val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] if (currentKey != nextKey) { @@ -386,6 +392,10 @@ object WriteOutput extends Logging { logDebug(s"Writing partition: $currentKey") currentWriter = newOutputWriter(currentKey, getPartitionString) + val partitionPath = getPartitionString(currentKey).getString(0) + if (partitionPath.nonEmpty) { + updatedPartitions.add(partitionPath) + } } currentWriter.writeInternal(sortedIterator.getValue) } @@ -393,6 +403,7 @@ object WriteOutput extends Logging { currentWriter.close() currentWriter = null } + updatedPartitions.toSet } override def releaseResources(): Unit = { From 7d6c87155c740cf622c2c600a8ca64154d24c422 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 31 Oct 2016 20:23:22 -0700 Subject: [PATCH 063/381] [SPARK-18167][SQL] Retry when the SQLQuerySuite test flakes ## What changes were proposed in this pull request? This will re-run the flaky test a few times after it fails. This will help determine if it's due to nondeterministic test setup, or because of some environment issue (e.g. leaked config from another test). cc yhuai Author: Eric Liang Closes #15708 from ericl/spark-18167-3. 
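For context on the change below: the retry is implemented inline in the test, but the same idea can be expressed as a small helper. The sketch below is illustrative only and not part of the patch; the helper name and retry count are invented.

``` scala
// Illustrative sketch: re-run a failing test body a few extra times to gather more
// diagnostic output, then rethrow the original failure so the test still fails.
def retryForDiagnostics[T](extraRuns: Int)(body: => T): T = {
  try {
    body
  } catch {
    case t: Throwable =>
      (1 to extraRuns).foreach(_ => scala.util.Try(body))
      throw t
  }
}
```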
--- .../sql/hive/execution/SQLQuerySuite.scala | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2735d3a5267e..f64010a64b01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1566,14 +1566,26 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-10562: partition by column with mixed case name") { - withTable("tbl10562") { - val df = Seq(2012 -> "a").toDF("Year", "val") - df.write.partitionBy("Year").saveAsTable("tbl10562") - checkAnswer(sql("SELECT year FROM tbl10562"), Row(2012)) - checkAnswer(sql("SELECT Year FROM tbl10562"), Row(2012)) - checkAnswer(sql("SELECT yEAr FROM tbl10562"), Row(2012)) - checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year > 2015"), Nil) - checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) + def runOnce() { + withTable("tbl10562") { + val df = Seq(2012 -> "a").toDF("Year", "val") + df.write.partitionBy("Year").saveAsTable("tbl10562") + checkAnswer(sql("SELECT year FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT Year FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT yEAr FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year > 2015"), Nil) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) + } + } + try { + runOnce() + } catch { + case t: Throwable => + // Retry to gather more test data. TODO(ekl) revert this once we deflake this test. + runOnce() + runOnce() + runOnce() + throw t } } From d9d1465009fb40550467089ede315496552374c5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 31 Oct 2016 22:23:38 -0700 Subject: [PATCH 064/381] [SPARK-18024][SQL] Introduce an internal commit protocol API ## What changes were proposed in this pull request? This patch introduces an internal commit protocol API that is used by the batch data source to do write commits. It currently has only one implementation that uses Hadoop MapReduce's OutputCommitter API. In the future, this commit API can be used to unify streaming and batch commits. ## How was this patch tested? Should be covered by existing write tests. Author: Reynold Xin Author: Eric Liang Closes #15707 from rxin/SPARK-18024-2. 
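To make the intended call sequence concrete, here is a rough sketch of how a write job is expected to drive the new commit protocol, based on the API added below. The method names and signatures come from the patch; the surrounding wiring (how the Hadoop `Job` and `TaskAttemptContext` are obtained, the output path, the partition directory, and the file extension) is assumed for illustration.

``` scala
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.sql.execution.datasources.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage

// Driver side: instantiate the committer and set up the job once.
def setupOnDriver(job: Job): FileCommitProtocol = {
  val committer = FileCommitProtocol.instantiate(
    className = "org.apache.spark.sql.execution.datasources.HadoopCommitProtocolWrapper",
    outputPath = "/path/to/output",
    isAppend = false)
  committer.setupJob(job)
  committer
}

// Executor side: each task asks the committer where to write, writes its rows,
// and commits the task, returning a TaskCommitMessage to the driver.
def runTask(committer: FileCommitProtocol, ctx: TaskAttemptContext): TaskCommitMessage = {
  committer.setupTask(ctx)
  val path = committer.newTaskTempFile(ctx, dir = Some("part=1"), ext = ".snappy.parquet")
  // ... write rows to `path` using an OutputWriter ...
  committer.commitTask(ctx)
}

// Driver side again: commit the whole job once every task has reported success.
def finishOnDriver(
    committer: FileCommitProtocol, job: Job, commits: Seq[TaskCommitMessage]): Unit = {
  committer.commitJob(job, commits)
}
```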
--- .../ml/source/libsvm/LibSVMRelation.scala | 17 +- .../datasources/FileCommitProtocol.scala | 254 ++++++++++++++++++ .../execution/datasources/OutputWriter.scala | 26 +- .../execution/datasources/WriteOutput.scala | 167 +++--------- .../datasources/csv/CSVRelation.scala | 17 +- .../datasources/json/JsonFileFormat.scala | 17 +- .../parquet/ParquetFileFormat.scala | 8 +- .../parquet/ParquetOutputWriter.scala | 19 +- .../datasources/text/TextFileFormat.scala | 17 +- .../apache/spark/sql/internal/SQLConf.scala | 29 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 28 +- .../sql/sources/CommitFailureTestSource.scala | 10 +- .../sql/sources/SimpleTextRelation.scala | 19 +- 13 files changed, 387 insertions(+), 241 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5e9e6ff1a569..cb3ca1b6c4be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -41,17 +41,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString - } - private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { @@ -135,11 +129,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour dataSchema: StructType): OutputWriterFactory = { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new LibSVMOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".libsvm" + TextOutputWriter.getCompressionExtension(context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala new file mode 100644 index 000000000000..1ce9ae4266c1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.SparkHadoopWriter +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + + +object FileCommitProtocol { + class TaskCommitMessage(obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(Unit) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, outputPath: String, isAppend: Boolean): FileCommitProtocol = { + try { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[Boolean]) + ctor.newInstance(outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String]) + ctor.newInstance(outputPath) + } + } catch { + case e: ClassNotFoundException => + throw e + } + } +} + + +/** + * An interface to define how a Spark job commits its outputs. Implementations must be serializable, + * as the committer instance instantiated on the driver will be used for tasks on executors. + * + * The proper call sequence is: + * + * 1. Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. When all necessary tasks completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Setups up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. 
file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopCommitProtocolWrapper(path: String, isAppend: Boolean) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** UUID used to identify the job in file name. */ + private val uuid: String = UUID.randomUUID().toString + + private def setupCommitter(context: TaskAttemptContext): Unit = { + committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + + if (!isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the appending job fails. + // See SPARK-8578 for more details. + val configuration = context.getConfiguration + val clazz = + configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + if (clazz != null) { + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(new Path(path), context) + } else { + // The specified output committer is just an OutputCommitter. + // So, we will use the no-argument constructor. 
+ val ctor = clazz.getDeclaredConstructor() + committer = ctor.newInstance() + } + } + } + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val filename = f"part-$split%05d-$uuid$ext" + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + setupCommitter(taskAttemptContext) + + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + setupCommitter(taskContext) + committer.setupTask(taskContext) + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + EmptyTaskCommitMessage + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index fbf6e96d3f85..a73c8146c1b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -30,28 +30,21 @@ import org.apache.spark.sql.types.StructType * to executor side to create actual [[OutputWriter]]s on the fly. */ abstract class OutputWriterFactory extends Serializable { + + /** Returns the file extension to be used when writing files out. */ + def getFileExtension(context: TaskAttemptContext): String + /** * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side * to instantiate new [[OutputWriter]]s. 
* - * @param stagingDir Base path (directory) of the file to which this [[OutputWriter]] is supposed - * to write. Note that this may not point to the final output file. For - * example, `FileOutputFormat` writes to temporary directories and then merge - * written files back to the final destination. In this case, `path` points to - * a temporary output file under the temporary directory. - * @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this - * prefix is used in the actual file name. For example, if the prefix is - * "part-1-2-3", then the file name must start with "part_1_2_3" but can - * end in arbitrary extension that is deterministic given the configuration - * (i.e. the suffix extension should not depend on any task id, attempt id, - * or partition id). + * @param path Path to write the file. * @param dataSchema Schema of the rows to be written. Partition columns are not included in the * schema if the relation being written is partitioned. * @param context The Hadoop MapReduce task context. */ def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter @@ -77,13 +70,6 @@ abstract class OutputWriterFactory extends Serializable { * executor side. This instance is used to persist rows to this single output file. */ abstract class OutputWriter { - - /** - * The path of the file to be written out. This path should include the staging directory and - * the file name prefix passed into the associated createOutputWriter function. - */ - def path: String - /** * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala index 0eb86fdd6caa..a07855111b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala @@ -24,12 +24,11 @@ import scala.collection.mutable import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -38,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter @@ -57,8 +56,7 @@ object WriteOutput extends Logging { val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], val isAppend: Boolean, - val path: String, - val outputFormatClass: Class[_ <: OutputFormat[_, _]]) + val path: String) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -114,31 +112,38 @@ object WriteOutput extends Logging { nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, isAppend = isAppend, - path = outputPath.toString, - outputFormatClass = job.getOutputFormatClass) + path = outputPath.toString) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- val committer = setupDriverCommitter(job, outputPath.toString, isAppend) + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + outputPath.toString, + isAppend) + committer.setupJob(job) try { - val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd, + val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, sparkStageId = taskContext.stageId(), sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.attemptNumber(), + committer, iterator = iter) - }).flatten.distinct + }) - committer.commitJob(job) + val commitMsgs = ret.map(_._1) + val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment) + + committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") - refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment)) + refreshFunction(updatedPartitions) } catch { case cause: Throwable => logError(s"Aborting job ${job.getJobID}.", cause) - committer.abortJob(job, JobStatus.State.FAILED) + committer.abortJob(job) throw new SparkException("Job aborted.", cause) } } @@ -150,7 +155,8 @@ object WriteOutput extends Logging { sparkStageId: Int, sparkPartitionId: Int, sparkAttemptNumber: Int, - iterator: Iterator[InternalRow]): Set[String] = { + committer: FileCommitProtocol, + iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = { val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -169,33 +175,21 @@ object WriteOutput extends Logging { new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } - val committer = newOutputCommitter( - description.outputFormatClass, taskAttemptContext, description.path, description.isAppend) committer.setupTask(taskAttemptContext) - // Figure out where we need to write data to for staging. - // For FileOutputCommitter it has its own staging path called "work path". - val stagingPath = committer match { - case f: FileOutputCommitter => f.getWorkPath.toString - case _ => description.path - } - val writeTask = if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, stagingPath) + new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, stagingPath) + new DynamicPartitionWriteTask(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - // Execute the task to write rows out - val outputPaths = writeTask.execute(iterator) + // Execute the task to write rows out and commit the task. + val outputPartitions = writeTask.execute(iterator) writeTask.releaseResources() - - // Commit the task - SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId) - outputPaths + (committer.commitTask(taskAttemptContext), outputPartitions) })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -217,27 +211,28 @@ object WriteOutput extends Logging { * automatically trigger task aborts. */ private trait ExecuteWriteTask { + /** + * Writes data out to files, and then returns the list of partition strings written out. + * The list of partitions is sent back to the driver and used to update the catalog. 
+ */ def execute(iterator: Iterator[InternalRow]): Set[String] def releaseResources(): Unit - - final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = { - val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - f"part-r-$split%05d-$uuid$bucketString" - } } /** Writes data to a single directory (used for non-dynamic-partition writes). */ private class SingleDirectoryWriteTask( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - stagingPath: String) extends ExecuteWriteTask { + committer: FileCommitProtocol) extends ExecuteWriteTask { private[this] var outputWriter: OutputWriter = { - val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + val tmpFilePath = committer.newTaskTempFile( + taskAttemptContext, + None, + description.outputWriterFactory.getFileExtension(taskAttemptContext)) val outputWriter = description.outputWriterFactory.newInstance( - stagingDir = stagingPath, - fileNamePrefix = filePrefix(split, description.uuid, None), + path = tmpFilePath, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) @@ -267,7 +262,7 @@ object WriteOutput extends Logging { private class DynamicPartitionWriteTask( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - stagingPath: String) extends ExecuteWriteTask { + committer: FileCommitProtocol) extends ExecuteWriteTask { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ @@ -307,25 +302,20 @@ object WriteOutput extends Logging { * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet */ private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = { - val path = - if (description.partitionColumns.nonEmpty) { - val partitionPath = partString(key).getString(0) - new Path(stagingPath, partitionPath).toString - } else { - stagingPath - } + val partDir = + if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) // If the bucket spec is defined, the bucket column is right after the partition columns val bucketId = if (description.bucketSpec.isDefined) { - Some(key.getInt(description.partitionColumns.length)) + BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length)) } else { - None + "" } + val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + val path = committer.newTaskTempFile(taskAttemptContext, partDir, ext) val newWriter = description.outputWriterFactory.newInstance( - stagingDir = path, - fileNamePrefix = filePrefix(split, description.uuid, bucketId), + path = path, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) newWriter.initConverter(description.nonPartitionColumns.toStructType) @@ -413,75 +403,4 @@ object WriteOutput extends Logging { } } } - - private def setupDriverCommitter(job: Job, path: String, isAppend: Boolean): OutputCommitter = { - // Setup IDs - val jobId = SparkHadoopWriter.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - job.getConfiguration.set("mapred.job.id", jobId.toString) - job.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - 
job.getConfiguration.set("mapred.task.id", taskAttemptId.toString) - job.getConfiguration.setBoolean("mapred.task.is.map", true) - job.getConfiguration.setInt("mapred.task.partition", 0) - - val taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId) - val outputCommitter = newOutputCommitter( - job.getOutputFormatClass, taskAttemptContext, path, isAppend) - outputCommitter.setupJob(job) - outputCommitter - } - - private def newOutputCommitter( - outputFormatClass: Class[_ <: OutputFormat[_, _]], - context: TaskAttemptContext, - path: String, - isAppend: Boolean): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the appending job fails. - // See SPARK-8578 for more details - logInfo( - s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val configuration = context.getConfiguration - val clazz = - configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - if (clazz != null) { - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(path), context) - } else { - // The specified output committer is just an OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - } else { - // If output committer class is not set, we will use the one associated with the - // file output format. 
- logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index a35cfdb2c234..a249b9d9d59b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -171,26 +171,23 @@ object CSVRelation extends Logging { private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new CsvOutputWriter(stagingDir, fileNamePrefix, dataSchema, context, params) + new CsvOutputWriter(path, dataSchema, context, params) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".csv" + TextOutputWriter.getCompressionExtension(context) } } private[csv] class CsvOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".csv" + compressionExtension).toString - } - // create the Generator without separator inserted between 2 records private[this] val text = new Text() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 651fa78a4e92..5a409c04c929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -83,11 +83,14 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(stagingDir, parsedOptions, fileNamePrefix, dataSchema, context) + new JsonOutputWriter(path, parsedOptions, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".json" + TextOutputWriter.getCompressionExtension(context) } } } @@ -154,18 +157,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } private[json] class JsonOutputWriter( - stagingDir: String, + path: String, options: JSONOptions, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with Logging { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".json" + compressionExtension).toString - } - private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 502dd0e8d4cf..77c83ba38efe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -33,6 +33,7 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler @@ -133,10 +134,13 @@ class ParquetFileFormat new OutputWriterFactory { override def newInstance( path: String, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, fileNamePrefix, context) + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 1300069c42b0..92d4f27be3fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -89,7 +89,7 @@ private[parquet] class ParquetOutputWriterFactory( * Returns a [[OutputWriter]] that writes data to the give path without using * [[OutputCommitter]]. */ - override def newWriter(path1: String): OutputWriter = new OutputWriter { + override def newWriter(path: String): OutputWriter = new OutputWriter { // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) @@ -99,8 +99,6 @@ private[parquet] class ParquetOutputWriterFactory( // Instance of ParquetRecordWriter that does not use OutputCommitter private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - override def path: String = path1 - override def write(row: Row): Unit = { throw new UnsupportedOperationException("call writeInternal") } @@ -127,27 +125,22 @@ private[parquet] class ParquetOutputWriterFactory( /** Disable the use of the older API. */ override def newInstance( path: String, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { throw new UnsupportedOperationException("this version of newInstance not supported for " + "ParquetOutputWriterFactory") } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
-private[parquet] class ParquetOutputWriter( - stagingDir: String, - fileNamePrefix: String, - context: TaskAttemptContext) +private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val filename = fileNamePrefix + CodecConfig.from(context).getCodec.getExtension + ".parquet" - new Path(stagingDir, filename).toString - } - private val recordWriter: RecordWriter[Void, InternalRow] = { new ParquetOutputFormat[InternalRow]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index d40b5725199a..8e043960326d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -75,11 +75,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new TextOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".txt" + TextOutputWriter.getCompressionExtension(context) } } } @@ -124,17 +127,11 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } class TextOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".txt" + compressionExtension).toString - } - private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc31f3bc323f..29e79847aa38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.execution.datasources.HadoopCommitProtocolWrapper import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -240,9 +241,8 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + - "option must be set in Hadoop Configuration. 2. 
This option overrides " + - "\"spark.sql.sources.outputCommitterClass\".") + "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + .internal() .stringConf .createWithDefault(classOf[ParquetOutputCommitter].getName) @@ -375,16 +375,17 @@ object SQLConf { .booleanConf .createWithDefault(true) - // The output committer class used by HadoopFsRelation. The specified class needs to be a + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - // - // NOTE: - // - // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. - // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". val OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.sources.outputCommitterClass").internal().stringConf.createOptional + val FILE_COMMIT_PROTOCOL_CLASS = + SQLConfigBuilder("spark.sql.sources.commitProtocolClass") + .internal() + .stringConf + .createWithDefault(classOf[HadoopCommitProtocolWrapper].getName) + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold") .doc("The maximum number of files allowed for listing files at driver side. If the number " + @@ -518,6 +519,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = + SQLConfigBuilder("spark.sql.streaming.commitProtocolClass") + .internal() + .stringConf + .createWithDefault(classOf[HadoopCommitProtocolWrapper].getName) + val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") .internal() .doc("Whether to delete the expired log files in file stream sink.") @@ -631,6 +638,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) + def streamingFileCommitProtocolClass: String = getConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS) + def fileSinkLogDeletion: Boolean = getConf(FILE_SINK_LOG_DELETION) def fileSinkLogCompactInterval: Int = getConf(FILE_SINK_LOG_COMPACT_INTERVAL) @@ -741,6 +750,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def partitionColumnTypeInferenceEnabled: Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + def fileCommitProtocolClass: String = getConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS) + def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index eba7aa386ade..7c519a074317 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -83,11 +83,19 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) + 
OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" } } } @@ -210,23 +218,11 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") - } - // It has the `.orc` extension at the end because (de)compression tools - // such as gunzip would not be able to decompress this as the compression - // is not applied on this whole file but on each "stream" in ORC format. - new Path(stagingDir, fileNamePrefix + compressionExtension + ".orc").toString - } - private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index 731540db17ee..abc7c8cc4db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -40,19 +39,16 @@ class CommitFailureTestSource extends SimpleTextSource { dataSchema: StructType): OutputWriterFactory = new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) { + new SimpleTextOutputWriter(path, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true SimpleTextRelation.callbackCalled = true } - override val path: String = new Path(stagingDir, fileNamePrefix).toString - override def write(row: Row): Unit = { if (SimpleTextRelation.failWriter) { sys.error("Intentional task writer failure for testing purpose.") @@ -67,6 +63,8 @@ class CommitFailureTestSource extends SimpleTextSource { } } } + + override def getFileExtension(context: TaskAttemptContext): String = "" } override def shortName(): String = "commit-failure-test" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9896b9bde99c..64d0ecbeefc9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -51,12 +51,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { SimpleTextRelation.lastHadoopConf = Option(job.getConfiguration) new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) + new SimpleTextOutputWriter(path, context) } + + override def 
getFileExtension(context: TaskAttemptContext): String = "" } } @@ -120,14 +121,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter( - stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext) +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - override val path: String = new Path(stagingDir, fileNamePrefix).toString - private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context) + new AppendingTextOutputFormat(path).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map { v => @@ -141,15 +139,14 @@ class SimpleTextOutputWriter( } } -class AppendingTextOutputFormat(stagingDir: Path, fileNamePrefix: String) - extends TextOutputFormat[NullWritable, Text] { +class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] { val numberFormat = NumberFormat.getInstance() numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(stagingDir, fileNamePrefix) + new Path(path) } }

From dd85eb5448c8f2672260b57e94c0da0eaac12616 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 1 Nov 2016 00:24:08 -0700
Subject: [PATCH 065/381] [SPARK-18107][SQL] Insert overwrite statement runs much slower in spark-sql than it does in hive-client

## What changes were proposed in this pull request?

As reported on the jira, the insert overwrite statement runs much slower in Spark than in hive-client. There is a patch [HIVE-11940](https://github.com/apache/hive/commit/ba21806b77287e237e1aa68fa169d2a81e07346d) which largely improves insert overwrite performance on Hive, but HIVE-11940 was only applied after Hive 2.0.0. Because Spark SQL uses an older Hive library, we cannot benefit from that improvement. The reporter verified that there is also a big performance gap between Hive 1.2.1 (520.037 secs) and Hive 2.0.1 (35.975 secs) on insert overwrite execution.

Instead of upgrading to Hive 2.0 in Spark SQL, which might not be a trivial task, this patch takes the approach of deleting the partition before asking Hive to load data files into the partition.

Note: The case reported on the jira is insert overwrite to a partition. Since `Hive.loadTable` also uses the same function to replace files, insert overwrite to a table should have the same issue; we can take the same approach and delete the table directory first. I will update this patch to include that.

## How was this patch tested?

Jenkins tests. There are existing tests using the insert overwrite statement; those tests should pass. I added a new test specifically for insert overwrite into a partition. For the performance issue, as I don't have a Hive 2.0 environment, this needs the reporter to verify it. Please refer to the jira.

Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request.

Author: Liang-Chi Hsieh

Closes #15667 from viirya/improve-hive-insertoverwrite.
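To make the delete-before-load idea above concrete, here is a minimal sketch of it in isolation. This is only an illustration under assumed names (`partitionLocation`, `hadoopConf`, and the returned `doHiveOverwrite` flag are not from the patch); the authoritative change is the `InsertIntoHiveTable` diff that follows.

```scala
// Illustrative sketch only (assumed names); the real logic lives in InsertIntoHiveTable below.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Returns the value to use for Hive's isOverwrite flag: once Spark has removed the partition
// directory itself, Hive can do a plain (non-overwrite) load and skip its slow replace-files path.
def prepareOverwrite(partitionLocation: String, hadoopConf: Configuration): Boolean = {
  val partitionPath = new Path(partitionLocation)
  val fs = partitionPath.getFileSystem(hadoopConf)
  if (fs.exists(partitionPath)) {
    if (!fs.delete(partitionPath, true)) {
      throw new RuntimeException(s"Cannot remove partition directory '$partitionPath'")
    }
    false // directory already cleared, so don't let Hive do the overwrite
  } else {
    true  // nothing to clear; keep the original overwrite behaviour
  }
}
```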
--- .../hive/execution/InsertIntoHiveTable.scala | 24 +++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 33 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c3c4e2925b90..2843100fb3b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, AlterTableDropPartitionCommand} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.SparkException @@ -257,7 +258,28 @@ case class InsertIntoHiveTable( table.catalogTable.identifier.table, partitionSpec) + var doHiveOverwrite = overwrite + if (oldPart.isEmpty || !ifNotExists) { + // SPARK-18107: Insert overwrite runs much slower than hive-client. + // Newer Hive largely improves insert overwrite performance. As Spark uses older Hive + // version and we may not want to catch up new Hive version every time. We delete the + // Hive partition first and then load data file into the Hive partition. + if (oldPart.nonEmpty && overwrite) { + oldPart.get.storage.locationUri.map { uri => + val partitionPath = new Path(uri) + val fs = partitionPath.getFileSystem(hadoopConf) + if (fs.exists(partitionPath)) { + if (!fs.delete(partitionPath, true)) { + throw new RuntimeException( + "Cannot remove partition directory '" + partitionPath.toString) + } + // Don't let Hive do overwrite operation since it is slower. + doHiveOverwrite = false + } + } + } + // inheritTableSpecs is set to true. It should be set to false for an IMPORT query // which is currently considered as a Hive native command. 
val inheritTableSpecs = true @@ -266,7 +288,7 @@ case class InsertIntoHiveTable( table.catalogTable.identifier.table, outputPath.toString, partitionSpec, - isOverwrite = overwrite, + isOverwrite = doHiveOverwrite, holdDDLTime = holdDDLTime, inheritTableSpecs = inheritTableSpecs) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f64010a64b01..8b916932ff54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1973,6 +1973,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Insert overwrite with partition") { + withTable("tableWithPartition") { + sql( + """ + |CREATE TABLE tableWithPartition (key int, value STRING) + |PARTITIONED BY (part STRING) + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE tableWithPartition PARTITION (part = '1') + |SELECT * FROM default.src + """.stripMargin) + checkAnswer( + sql("SELECT part, key, value FROM tableWithPartition"), + sql("SELECT '1' AS part, key, value FROM default.src") + ) + + sql( + """ + |INSERT OVERWRITE TABLE tableWithPartition PARTITION (part = '1') + |SELECT * FROM VALUES (1, "one"), (2, "two"), (3, null) AS data(key, value) + """.stripMargin) + checkAnswer( + sql("SELECT part, key, value FROM tableWithPartition"), + sql( + """ + |SELECT '1' AS part, key, value FROM VALUES + |(1, "one"), (2, "two"), (3, null) AS data(key, value) + """.stripMargin) + ) + } + } + def testCommandAvailable(command: String): Boolean = { val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue()) attempt.isSuccess && attempt.get == 0 From 623fc7fc67735cfafdb7f527bd3df210987943c6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 Nov 2016 13:08:49 +0000 Subject: [PATCH 066/381] [MINOR][DOC] Remove spaces following slashs ## What changes were proposed in this pull request? This PR merges multiple lines enumerating items in order to remove the redundant spaces following slashes in [Structured Streaming Programming Guide in 2.0.2-rc1](http://people.apache.org/~pwendell/spark-releases/spark-2.0.2-rc1-docs/structured-streaming-programming-guide.html). - Before: `Scala/ Java/ Python` - After: `Scala/Java/Python` ## How was this patch tested? Manual by the followings because this is documentation update. ``` cd docs SKIP_API=1 jekyll build ``` Author: Dongjoon Hyun Closes #15686 from dongjoon-hyun/minor_doc_space. --- .../structured-streaming-programming-guide.md | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 173fd6e8c73b..d838ed35a14f 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -14,10 +14,8 @@ Structured Streaming is a scalable and fault-tolerant stream processing engine b # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. 
You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you -[download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
@@ -409,16 +407,15 @@ Delivering end-to-end exactly-once semantics was one of key goals behind the des to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames -Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` ([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/ -[Java](api/java/org/apache/spark/sql/SparkSession.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the +Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets Streaming DataFrames can be created through the `DataStreamReader` interface -([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/ -[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) +returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Data Sources In Spark 2.0, there are a few built-in sources. @@ -628,9 +625,7 @@ The result tables would look something like the following. ![Window Operations](img/structured-streaming-window.png) Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations. 
You can see the full code for the below examples in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py). +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
@@ -753,10 +748,9 @@ In addition, there are some Dataset methods that will not work on streaming Data If you try any of these operations, you will see an AnalysisException like "operation XYZ is not supported with streaming DataFrames/Datasets". ## Starting Streaming Queries -Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the -`DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/ -[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. +Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) +returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. - *Details of the output sink:* Data format, location, etc. @@ -953,8 +947,9 @@ spark.sql("select * from aggregates").show() # interactively query in-memory t
#### Using Foreach -The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/ -[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. +The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` +([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), +which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. - The writer must be serializable, as it will be serialized and sent to the executors for execution. @@ -1046,9 +1041,9 @@ query.sinkStatus() # progress information about data written to the output sin
-You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/ -[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/ -[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) that can be used to manage the currently active queries. +You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) +that can be used to manage the currently active queries.
@@ -1092,8 +1087,9 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/
-[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), which will give you regular callback-based updates when queries are started and terminated.
+Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener`
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs),
+which will give you regular callback-based updates when queries are started and terminated.

 ## Recovering from Failures with Checkpointing
 In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).

From cb80edc26349e2e358d27fe2ae8e5d6959b77fab Mon Sep 17 00:00:00 2001
From: wangzhenhua
Date: Tue, 1 Nov 2016 13:11:24 +0000
Subject: [PATCH 067/381] [SPARK-18111][SQL] Wrong ApproximatePercentile answer when multiple records have the minimum value

## What changes were proposed in this pull request?

When multiple records share the minimum value, ApproximatePercentile returns a wrong answer.

## How was this patch tested?

Added a test case.

Author: wangzhenhua

Closes #15641 from wzhfy/percentile.
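For context, the snippet below restates the scenario as a standalone reproduction, mirroring the data used by the new test in the diff that follows. It is illustrative only: it assumes an active `SparkSession` named `spark`, and the view name `repro_table` is arbitrary.

```scala
// Standalone restatement of the reproduction used by the new test below.
// Assumes an existing SparkSession called `spark`.
import spark.implicits._

// Many records share the minimum value 1; before this fix the approximate median could be wrong.
spark.sparkContext
  .makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4)
  .toDF("col")
  .createOrReplaceTempView("repro_table")

// Expected result: [1.0]
spark.sql("SELECT percentile_approx(col, array(0.5)) FROM repro_table").show()
```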
--- .../spark/sql/catalyst/util/QuantileSummaries.scala | 4 +++- .../spark/sql/ApproximatePercentileQuerySuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 27928c493d5f..04f4ff2a9224 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -264,7 +264,9 @@ object QuantileSummaries { res.prepend(head) // If necessary, add the minimum element: val currHead = currentSamples.head - if (currHead.value < head.value) { + // don't add the minimum element if `currentSamples` has only one element (both `currHead` and + // `head` point to the same element) + if (currHead.value <= head.value && currentSamples.length > 1) { res.prepend(currentSamples.head) } res.toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 37d7c442bbeb..e98092df4951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -64,6 +64,17 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, multiple records with the minimum value in a partition") { + withTempView(table) { + spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col") + .createOrReplaceTempView(table) + checkAnswer( + spark.sql(s"SELECT percentile_approx(col, array(0.5)) FROM $table"), + Row(Seq(1.0D)) + ) + } + } + test("percentile_approx, with different accuracies") { withTempView(table) { From e34b4e12673fb76c92f661d7c03527410857a0f8 Mon Sep 17 00:00:00 2001 From: Charles Allen Date: Tue, 1 Nov 2016 13:14:17 +0000 Subject: [PATCH 068/381] [SPARK-15994][MESOS] Allow enabling Mesos fetch cache in coarse executor backend Mesos 0.23.0 introduces a Fetch Cache feature http://mesos.apache.org/documentation/latest/fetcher/ which allows caching of resources specified in command URIs. This patch: - Updates the Mesos shaded protobuf dependency to 0.23.0 - Allows setting `spark.mesos.fetcherCache.enable` to enable the fetch cache for all specified URIs. (URIs must be specified for the setting to have any affect) - Updates documentation for Mesos configuration with the new setting. This patch does NOT: - Allow for per-URI caching configuration. The cache setting is global to ALL URIs for the command. Author: Charles Allen Closes #13713 from drcrallen/SPARK15994. --- docs/running-on-mesos.md | 9 ++++-- .../cluster/mesos/MesosClusterScheduler.scala | 3 +- .../MesosCoarseGrainedSchedulerBackend.scala | 6 ++-- .../cluster/mesos/MesosSchedulerUtils.scala | 6 ++-- ...osCoarseGrainedSchedulerBackendSuite.scala | 28 +++++++++++++++++++ 5 files changed, 45 insertions(+), 7 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 77b06fcf3374..923d8dbebf3d 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -506,8 +506,13 @@ See the [configuration page](configuration.html) for information on Spark config since this configuration is just a upper limit and not a guaranteed amount. 
- - + + spark.mesos.fetcherCache.enable + false + + If set to `true`, all URIs (example: `spark.executor.uri`, `spark.mesos.uris`) will be cached by the [Mesos fetcher cache](http://mesos.apache.org/documentation/latest/fetcher/) + + # Troubleshooting and Debugging diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 0b454997772d..635712c00d30 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -129,6 +129,7 @@ private[spark] class MesosClusterScheduler( private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val useFetchCache = conf.getBoolean("spark.mesos.fetchCache.enable", false) private val schedulerState = engineFactory.createEngine("scheduler") private val stateLock = new Object() private val finishedDrivers = @@ -396,7 +397,7 @@ private[spark] class MesosClusterScheduler( val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).build()) + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) } private def getDriverCommandValue(desc: MesosDriverDescription): String = { diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index e67bf3e328f9..5063c1fe988b 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -59,6 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) + val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) private[this] val shutdownTimeoutMS = @@ -226,10 +228,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get).setCache(useFetcherCache)) } - conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) + conf.getOption("spark.mesos.uris").foreach(setupUris(_, command, useFetcherCache)) command.build() } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 73cc241239c4..9cb60237044a 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -369,9 +369,11 @@ trait MesosSchedulerUtils extends Logging { sc.executorMemory } - def 
setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + def setupUris(uris: String, + builder: CommandInfo.Builder, + useFetcherCache: Boolean = false): Unit = { uris.split(",").foreach { uri => - builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetcherCache)) } } diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 75ba02e470e2..f73638fda623 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -463,6 +463,34 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getCommand.getUrisList.asScala(0).getValue == url) } + test("mesos supports setting fetcher cache") { + val url = "spark.spark.spark.com" + setBackend(Map( + "spark.mesos.fetcherCache.enable" -> "true", + "spark.executor.uri" -> url + ), false) + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + val uris = launchedTasks.head.getCommand.getUrisList + assert(uris.size() == 1) + assert(uris.asScala.head.getCache) + } + + test("mesos supports disabling fetcher cache") { + val url = "spark.spark.spark.com" + setBackend(Map( + "spark.mesos.fetcherCache.enable" -> "false", + "spark.executor.uri" -> url + ), false) + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + val uris = launchedTasks.head.getCommand.getUrisList + assert(uris.size() == 1) + assert(!uris.asScala.head.getCache) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def verifyDeclinedOffer(driver: SchedulerDriver, From ec6f479bb1d14c9eb45e0418353007be0416e4c5 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 1 Nov 2016 13:18:11 +0000 Subject: [PATCH 069/381] [SPARK-16881][MESOS] Migrate Mesos configs to use ConfigEntry ## What changes were proposed in this pull request? Migrate Mesos configs to use ConfigEntry ## How was this patch tested? Jenkins Tests Author: Sandeep Singh Closes #15654 from techaddict/SPARK-16881. 
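For readers who have not seen `ConfigEntry` before, the migration follows the pattern sketched below. It mirrors the entries added in the new `config.scala` further down; note that `ConfigBuilder` and `SparkConf.get(entry)` are Spark-internal APIs, so code like this only compiles inside the `org.apache.spark` namespace, exactly as the patch's package object does, and the object name here is only for illustration.

```scala
import java.util.concurrent.TimeUnit

import org.apache.spark.internal.config.ConfigBuilder

object MesosConfigSketch {
  // Before: conf.get("spark.deploy.recoveryMode", "NONE")
  val RECOVERY_MODE =
    ConfigBuilder("spark.deploy.recoveryMode")
      .stringConf
      .createWithDefault("NONE")

  // Before: conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
  val SHUFFLE_CLEANER_INTERVAL_S =
    ConfigBuilder("spark.shuffle.cleaner.interval")
      .timeConf(TimeUnit.SECONDS)
      .createWithDefaultString("30s")
}

// Call sites then read the typed entry instead of a raw string key:
//   val recoveryMode: String   = conf.get(MesosConfigSketch.RECOVERY_MODE)
//   val cleanerIntervalS: Long = conf.get(MesosConfigSketch.SHUFFLE_CLEANER_INTERVAL_S)
```

The benefit is that the key, its type and its default are declared once instead of being repeated at every call site.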
--- .../deploy/mesos/MesosClusterDispatcher.scala | 9 +-- .../mesos/MesosExternalShuffleService.scala | 3 +- .../apache/spark/deploy/mesos/config.scala | 59 +++++++++++++++++++ .../deploy/mesos/ui/MesosClusterPage.scala | 3 +- 4 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 73b6ca384438..7d6693b4cdf5 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging @@ -51,7 +52,7 @@ private[mesos] class MesosClusterDispatcher( extends Logging { private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase() + private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { @@ -74,7 +75,7 @@ private[mesos] class MesosClusterDispatcher( def start(): Unit = { webUi.bind() - scheduler.frameworkUrl = conf.get("spark.mesos.dispatcher.webui.url", webUi.activeWebUiUrl) + scheduler.frameworkUrl = conf.get(DISPATCHER_WEBUI_URL).getOrElse(webUi.activeWebUiUrl) scheduler.start() server.start() } @@ -99,8 +100,8 @@ private[mesos] object MesosClusterDispatcher extends Logging { conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => - conf.set("spark.deploy.recoveryMode", "ZOOKEEPER") - conf.set("spark.deploy.zookeeper.url", z) + conf.set(RECOVERY_MODE, "ZOOKEEPER") + conf.set(ZOOKEEPER_URL, z) } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 6b297c4600a6..859aa836a315 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler @@ -114,7 +115,7 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage protected override def newShuffleBlockHandler( conf: TransportConf): ExternalShuffleBlockHandler = { - val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s") + val cleanerIntervalS = this.conf.get(SHUFFLE_CLEANER_INTERVAL_S) new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS) } } diff 
--git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala new file mode 100644 index 000000000000..19e253394f1b --- /dev/null +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config.ConfigBuilder + +package object config { + + /* Common app configuration. */ + + private[spark] val SHUFFLE_CLEANER_INTERVAL_S = + ConfigBuilder("spark.shuffle.cleaner.interval") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("30s") + + private[spark] val RECOVERY_MODE = + ConfigBuilder("spark.deploy.recoveryMode") + .stringConf + .createWithDefault("NONE") + + private[spark] val DISPATCHER_WEBUI_URL = + ConfigBuilder("spark.mesos.dispatcher.webui.url") + .doc("Set the Spark Mesos dispatcher webui_url for interacting with the " + + "framework. If unset it will point to Spark's internal web UI.") + .stringConf + .createOptional + + private[spark] val ZOOKEEPER_URL = + ConfigBuilder("spark.deploy.zookeeper.url") + .doc("When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this " + + "configuration is used to set the zookeeper URL to connect to.") + .stringConf + .createOptional + + private[spark] val HISTORY_SERVER_URL = + ConfigBuilder("spark.mesos.dispatcher.historyServer.url") + .doc("Set the URL of the history server. 
The dispatcher will then " + + "link each driver to its entry in the history server.") + .stringConf + .createOptional + +} diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 8dcbdaad8685..13ba7d311e57 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -23,12 +23,13 @@ import scala.xml.Node import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState import org.apache.spark.ui.{UIUtils, WebUIPage} private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { - private val historyServerURL = parent.conf.getOption("spark.mesos.dispatcher.historyServer.url") + private val historyServerURL = parent.conf.get(HISTORY_SERVER_URL) def render(request: HttpServletRequest): Seq[Node] = { val state = parent.scheduler.getSchedulerState() From 9b377aa49f14af31f54164378d60e0fdea2142e5 Mon Sep 17 00:00:00 2001 From: Wang Lei Date: Tue, 1 Nov 2016 13:42:10 +0000 Subject: [PATCH 070/381] [SPARK-18114][MESOS] Fix mesos cluster scheduler generage command option error ## What changes were proposed in this pull request? Enclose --conf option value with "" to support multi value configs like spark.driver.extraJavaOptions, without "", driver will fail to start. ## How was this patch tested? Jenkins Tests. Test in our production environment, also unit tests, It is a very small change. Author: Wang Lei Closes #15643 from LeightonWong/messos-cluster. --- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 635712c00d30..8db1d126d59b 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -482,7 +482,7 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s"$key=${shellEscape(value)}") } + options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } options } From f7c145d8ce14b23019099c509d5a2b6dfb1fe62c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 1 Nov 2016 15:41:45 +0100 Subject: [PATCH 071/381] [SPARK-17996][SQL] Fix unqualified catalog.getFunction(...) ## What changes were proposed in this pull request? Currently an unqualified `getFunction(..)`call returns a wrong result; the returned function is shown as temporary function without a database. 
For example: ``` scala> sql("create function fn1 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs'") res0: org.apache.spark.sql.DataFrame = [] scala> spark.catalog.getFunction("fn1") res1: org.apache.spark.sql.catalog.Function = Function[name='fn1', className='org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs', isTemporary='true'] ``` This PR fixes this by adding database information to ExpressionInfo (which is used to store the function information). ## How was this patch tested? Added more thorough tests to `CatalogSuite`. Author: Herman van Hovell Closes #15542 from hvanhovell/SPARK-17996. --- .../sql/catalyst/expressions/ExpressionInfo.java | 14 ++++++++++++-- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/catalog/SessionCatalog.scala | 10 ++++++++-- .../spark/sql/execution/command/functions.scala | 5 +++-- .../apache/spark/sql/internal/CatalogImpl.scala | 6 +++--- .../apache/spark/sql/internal/CatalogSuite.scala | 15 ++++++++++++--- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index ba8e9cb4be28..4565ed44877a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -25,6 +25,7 @@ public class ExpressionInfo { private String usage; private String name; private String extended; + private String db; public String getClassName() { return className; @@ -42,14 +43,23 @@ public String getExtended() { return extended; } - public ExpressionInfo(String className, String name, String usage, String extended) { + public String getDb() { + return db; + } + + public ExpressionInfo(String className, String db, String name, String usage, String extended) { this.className = className; + this.db = db; this.name = name; this.usage = usage; this.extended = extended; } public ExpressionInfo(String className, String name) { - this(className, name, null, null); + this(className, null, name, null, null); + } + + public ExpressionInfo(String className, String db, String name) { + this(className, db, name, null, null); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b05f4f61f6a3..3e836ca375e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -495,7 +495,7 @@ object FunctionRegistry { val clazz = scala.reflect.classTag[T].runtimeClass val df = clazz.getAnnotation(classOf[ExpressionDescription]) if (df != null) { - new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()) + new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended()) } else { new ExpressionInfo(clazz.getCanonicalName, name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3d6eec81c03c..714ef825ab83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ 
-943,7 +943,10 @@ class SessionCatalog( requireDbExists(db) if (externalCatalog.functionExists(db, name.funcName)) { val metadata = externalCatalog.getFunction(db, name.funcName) - new ExpressionInfo(metadata.className, qualifiedName.unquotedString) + new ExpressionInfo( + metadata.className, + qualifiedName.database.orNull, + qualifiedName.identifier) } else { failFunctionLookup(name.funcName) } @@ -1000,7 +1003,10 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - val info = new ExpressionInfo(catalogFunction.className, qualifiedName.unquotedString) + val info = new ExpressionInfo( + catalogFunction.className, + qualifiedName.database.orNull, + qualifiedName.funcName) val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) // Now, we need to create the Expression. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 26593d2918a6..24d825f5cb33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -118,14 +118,15 @@ case class DescribeFunctionCommand( case _ => try { val info = sparkSession.sessionState.catalog.lookupFunctionInfo(functionName) + val name = if (info.getDb != null) info.getDb + "." + info.getName else info.getName val result = - Row(s"Function: ${info.getName}") :: + Row(s"Function: $name") :: Row(s"Class: ${info.getClassName}") :: Row(s"Usage: ${replaceFunctionName(info.getUsage, info.getName)}") :: Nil if (isExtended) { result :+ - Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, name)}") } else { result } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index f6c297e91b7c..44fd38dfb96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -133,11 +133,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { private def makeFunction(funcIdent: FunctionIdentifier): Function = { val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) new Function( - name = funcIdent.identifier, - database = funcIdent.database.orNull, + name = metadata.getName, + database = metadata.getDb, description = null, // for now, this is always undefined className = metadata.getClassName, - isTemporary = funcIdent.database.isEmpty) + isTemporary = metadata.getDb == null) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 214bc736bd4d..89ec162c8ed5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -386,15 +386,24 @@ class CatalogSuite createFunction("fn2", Some(db)) // Find a temporary function - assert(spark.catalog.getFunction("fn1").name === "fn1") + val fn1 = 
spark.catalog.getFunction("fn1") + assert(fn1.name === "fn1") + assert(fn1.database === null) + assert(fn1.isTemporary) // Find a qualified function - assert(spark.catalog.getFunction(db, "fn2").name === "fn2") + val fn2 = spark.catalog.getFunction(db, "fn2") + assert(fn2.name === "fn2") + assert(fn2.database === db) + assert(!fn2.isTemporary) // Find an unqualified function using the current database intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) - assert(spark.catalog.getFunction("fn2").name === "fn2") + val unqualified = spark.catalog.getFunction("fn2") + assert(unqualified.name === "fn2") + assert(unqualified.database === db) + assert(!unqualified.isTemporary) } } } From 5441a6269e00e3903ae6c1ea8deb4ddf3d2e9975 Mon Sep 17 00:00:00 2001 From: eyal farago Date: Tue, 1 Nov 2016 17:12:20 +0100 Subject: [PATCH 072/381] [SPARK-16839][SQL] redundant aliases after cleanupAliases ## What changes were proposed in this pull request? Simplify struct creation, especially the aspect of `CleanupAliases` which missed some aliases when handling trees created by `CreateStruct`. This PR includes: 1. A failing test (create struct with nested aliases, some of the aliases survive `CleanupAliases`). 2. A fix that transforms `CreateStruct` into a `CreateNamedStruct` constructor, effectively eliminating `CreateStruct` from all expression trees. 3. A `NamePlaceHolder` used by `CreateStruct` when column names cannot be extracted from unresolved `NamedExpression`. 4. A new Analyzer rule that resolves `NamePlaceHolder` into a string literal once the `NamedExpression` is resolved. 5. `CleanupAliases` code was simplified as it no longer has to deal with `CreateStruct`'s top level columns. ## How was this patch tested? running all tests-suits in package org.apache.spark.sql, especially including the analysis suite, making sure added test initially fails, after applying suggested fix rerun the entire analysis package successfully. modified few tests that expected `CreateStruct` which is now transformed into `CreateNamedStruct`. Credit goes to hvanhovell for assisting with this PR. Author: eyal farago Author: eyal farago Author: Herman van Hovell Author: Eyal Farago Author: Hyukjin Kwon Author: eyalfa Closes #14444 from eyalfa/SPARK-16839_redundant_aliases_after_cleanupAliases. 
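A small DataFrame-level sketch of the kind of expression this patch tidies up (assuming a `SparkSession` named `spark`): the aliases nested inside `struct(...)` supply the struct's field names, and with this change `struct` is planned as a `CreateNamedStruct`, so those nested `Alias` nodes no longer need to survive `CleanupAliases`.

```scala
import org.apache.spark.sql.functions.struct
import spark.implicits._

// Field names of the struct come from the nested aliases "a" and "a_plus_1".
val df = spark.range(2).select(
  struct($"id".as("a"), ($"id" + 1).as("a_plus_1")).as("col"))

df.printSchema()
// The schema shows `col` as a struct with fields `a` and `a_plus_1`; after this
// patch the analyzed plan contains CreateNamedStruct with literal field names
// rather than CreateStruct wrapping aliased children.
```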
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 53 ++--- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 - .../expressions/complexTypeCreator.scala | 211 ++++++------------ .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++- .../expressions/ComplexTypeSuite.scala | 1 - .../scala/org/apache/spark/sql/Column.scala | 3 + .../command/AnalyzeColumnCommand.scala | 4 +- .../resources/sql-tests/inputs/group-by.sql | 2 +- .../sql-tests/results/group-by.sql.out | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 20 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 15 files changed, 170 insertions(+), 200 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9289db57b6d6..5002655fc03c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8f4799322b3..5011f2fdbf9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,6 +83,7 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: + ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -653,11 +654,12 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateStruct if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) + case c: CreateNamedStruct if containsStar(c.valExprs) => + val newChildren = c.children.grouped(2).flatMap { + case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children + case kv => kv + } + 
c.copy(children = newChildren.toList ) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1141,7 +1143,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case CreateStruct(exprs) => exprs + case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { - var stop = false e.transformDown { - // CreateStruct is a special case, we need to retain its top level Aliases as they decide the - // name of StructField. We also need to stop transform down this expression, or the Aliases - // under CreateStruct will be mistakenly trimmed. - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } @@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => - var stop = false other transformExpressionsDown { - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } } @@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } + +/** + * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
+ */ +object ResolveCreateNamedStruct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case e: CreateNamedStruct if !e.resolved => + val children = e.children.grouped(2).flatMap { + case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => + Seq(Literal(e.name), e) + case kv => + kv + } + CreateNamedStruct(children.toList) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3e836ca375e2..b028d07fb8d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[CreateStruct]("struct"), + CreateStruct.registryEntry, // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a81fa1ce3adc..03e054d09851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,7 +119,6 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -145,7 +144,6 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 917aa0873130..e9623f96e1cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -172,101 +174,70 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * Returns a Row containing the evaluation of all children expressions. 
+ * An expression representing a not yet available attribute name. This expression is unevaluable + * and as its name suggests it is a temporary place holder until we're able to determine the + * actual attribute name. */ -@ExpressionDescription( - usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") -case class CreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - +case object NamePlaceholder extends LeafExpression with Unevaluable { + override lazy val resolved: Boolean = false + override def foldable: Boolean = false override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "NamePlaceholder" + override def toString: String = prettyName +} - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) +/** + * Returns a Row containing the evaluation of all children expressions. + */ +object CreateStruct extends FunctionBuilder { + def apply(children: Seq[Expression]): CreateNamedStruct = { + CreateNamedStruct(children.zipWithIndex.flatMap { + case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) + case (e: NamedExpression, _) => Seq(NamePlaceholder, e) + case (e, index) => Seq(Literal(s"col${index + 1}"), e) + }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) + /** + * Entry to use in the function registry. + */ + val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { + val info: ExpressionInfo = new ExpressionInfo( + "org.apache.spark.sql.catalyst.expressions.NamedStruct", + "struct", + "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", + "") + ("struct", (info, this)) } - - override def prettyName: String = "struct" } - /** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) + * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +trait CreateNamedStructLike extends Expression { + lazy val (nameExprs, valExprs) = children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } + lazy val names = nameExprs.map(_.eval(EmptyRow)) - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + override def nullable: Boolean = false - private lazy val names = nameExprs.map(_.eval(EmptyRow)) + override def foldable: Boolean = valExprs.forall(_.foldable) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + case (name, expr) => + val metadata = expr match { + case ne: NamedExpression => ne.metadata + case _ => Metadata.empty + } + StructField(name.toString, expr.dataType, expr.nullable, metadata) } StructType(fields) } - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -274,8 +245,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable StringType expressions are allowed to appear at odd position , got :" + - s" ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -284,9 +255,29 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } } + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } +} + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -316,44 +307,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "named_struct" } -/** - * Returns a Row containing the evaluation of all children expressions. 
This is a variant that - * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - */ -case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, children) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) - } - - override def prettyName: String = "struct_unsafe" -} - - /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -361,31 +314,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { - - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) - - override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) - } - StructType(fields) - } - - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 38e9bb6c162a..35aca91cf882 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -681,8 +681,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case CreateStruct(children) => children // style 1 - case child => Seq(child) // style 2 + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 590774c04304..817de48de279 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import org.scalatest.ShouldMatchers + import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -class AnalysisSuite extends AnalysisTest { + +class AnalysisSuite extends AnalysisTest with ShouldMatchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { @@ -218,9 +221,36 @@ class AnalysisSuite extends AnalysisTest { // CreateStruct is a special case that we should not trim Alias for it. plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) - plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) + expected = testRelation.select(CreateNamedStruct(Seq( + Literal(a.name), a, + Literal("a+1"), (a + 1))).as("col")) + checkAnalysis(plan, expected) + } + + test("Analysis may leave unnecassary aliases") { + val att1 = testRelation.output.head + var plan = testRelation.select( + CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), + att1 + ) + val prevPlan = getAnalyzer(true).execute(plan) + plan = prevPlan.select(CreateArray(Seq( + CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"), + /** alias should be eliminated by [[CleanupAliases]] */ + "col".attr.as("col2") + )).as("arr")) + plan = getAnalyzer(true).execute(plan) + + val expectedPlan = prevPlan.select( + CreateArray(Seq( + CreateNamedStruct(Seq( + Literal(att1.name), att1, + Literal("a_plus_1"), (att1 + 1))), + 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull + )).as("arr") + ) + + checkAnalysis(plan, expectedPlan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c307b2b8576..c21c6de32c0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,7 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 05e867bf5be9..067b0bac6303 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -183,6 +183,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { case a: AggregateExpression if 
a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) + // Wait until the struct is resolved. This will generate a nicer looking alias. + case struct: CreateNamedStructLike => UnresolvedAlias(struct) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index f873f34a845e..6141fab4aff0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 6741703d9d82..d496af686d75 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -14,4 +14,4 @@ select 'foo' from myview where int_col == 0 group by 1; select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; -- group-by should not produce any rows (sort aggregate). 
-select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; +select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9127bd4dd4c6..dede3a09ce75 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -44,8 +44,8 @@ struct -- !query 5 -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1 -- !query 5 schema -struct> +struct> -- !query 5 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6eb571b91ffa..90000445dffb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,6 +190,12 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } + private def quoteHiveFile(path : String) = if (Utils.isWindows) { + getHiveFile(path).getPath.replace('\\', '/') + } else { + getHiveFile(path).getPath + } + def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -225,16 +231,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -244,7 +250,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -269,7 +275,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -308,7 +314,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -379,7 +385,7 @@ private[hive] 
class TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index de0116a4dcba..cdda29af50e3 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index c7f10e569fa4..12d18dc87ceb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} +import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -109,12 +110,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceFile = getClass.getClassLoader.getResource(goldenFileName) - if (resourceFile == null) { + val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) + if (resourceStream == null) { throw new NoSuchFileException(goldenFileName) } - val 
path = resourceFile.getPath - val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) + val answerText = try { + Source.fromInputStream(resourceStream).mkString + } finally { + resourceStream.close + } val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() From 0cba535af3c65618f342fa2d7db9647f5e6f6f1b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 1 Nov 2016 17:30:37 +0100 Subject: [PATCH 073/381] Revert "[SPARK-16839][SQL] redundant aliases after cleanupAliases" This reverts commit 5441a6269e00e3903ae6c1ea8deb4ddf3d2e9975. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 53 +++-- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 + .../expressions/complexTypeCreator.scala | 211 ++++++++++++------ .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +--- .../expressions/ComplexTypeSuite.scala | 1 + .../scala/org/apache/spark/sql/Column.scala | 3 - .../command/AnalyzeColumnCommand.scala | 4 +- .../resources/sql-tests/inputs/group-by.sql | 2 +- .../sql-tests/results/group-by.sql.out | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 20 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 15 files changed, 200 insertions(+), 170 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 5002655fc03c..9289db57b6d6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, alias(struct("a", "c"), "d"))) + result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) - expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, alias(struct(df$a, df$b), "d"))) + result <- collect(select(df, struct(df$a, df$b))) expected <- data.frame(row.names = 1:2) - expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5011f2fdbf9b..f8f4799322b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,7 
+83,6 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: - ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -654,12 +653,11 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateNamedStruct if containsStar(c.valExprs) => - val newChildren = c.children.grouped(2).flatMap { - case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children - case kv => kv - } - c.copy(children = newChildren.toList ) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1143,7 +1141,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case cns : CreateNamedStruct => cns.valExprs + case CreateStruct(exprs) => exprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2074,8 +2072,18 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { + var stop = false e.transformDown { - case Alias(child, _) => child + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child } } @@ -2108,8 +2116,15 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => + var stop = false other transformExpressionsDown { - case Alias(child, _) => child + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child } } } @@ -2202,19 +2217,3 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } - -/** - * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
- */ -object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case e: CreateNamedStruct if !e.resolved => - val children = e.children.grouped(2).flatMap { - case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => - Seq(Literal(e.name), e) - case kv => - kv - } - CreateNamedStruct(children.toList) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b028d07fb8d0..3e836ca375e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - CreateStruct.registryEntry, + expression[CreateStruct]("struct"), // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 03e054d09851..a81fa1ce3adc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,6 +119,7 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -144,6 +145,7 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e9623f96e1cf..917aa0873130 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -174,70 +172,101 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * An expression representing a not yet available attribute name. 
This expression is unevaluable - * and as its name suggests it is a temporary place holder until we're able to determine the - * actual attribute name. + * Returns a Row containing the evaluation of all children expressions. */ -case object NamePlaceholder extends LeafExpression with Unevaluable { - override lazy val resolved: Boolean = false - override def foldable: Boolean = false +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") +case class CreateStruct(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + override def nullable: Boolean = false - override def dataType: DataType = StringType - override def prettyName: String = "NamePlaceholder" - override def toString: String = prettyName -} -/** - * Returns a Row containing the evaluation of all children expressions. - */ -object CreateStruct extends FunctionBuilder { - def apply(children: Seq[Expression]): CreateNamedStruct = { - CreateNamedStruct(children.zipWithIndex.flatMap { - case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) - case (e: NamedExpression, _) => Seq(NamePlaceholder, e) - case (e, index) => Seq(Literal(s"col${index + 1}"), e) - }) + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) } - /** - * Entry to use in the function registry. - */ - val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { - val info: ExpressionInfo = new ExpressionInfo( - "org.apache.spark.sql.catalyst.expressions.NamedStruct", - "struct", - "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", - "") - ("struct", (info, this)) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") + ctx.addMutableState("Object[]", values, s"this.$values = null;") + + ev.copy(code = s""" + boolean ${ev.isNull} = false; + this.$values = new Object[${children.size}];""" + + ctx.splitExpressions( + ctx.INPUT_ROW, + children.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + }""" + }) + + s""" + final InternalRow ${ev.value} = new $rowClass($values); + this.$values = null; + """) } + + override def prettyName: String = "struct" } + /** - * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) */ -trait CreateNamedStructLike extends Expression { - lazy val (nameExprs, valExprs) = children.grouped(2).map { - case Seq(name, value) => (name, value) - }.toList.unzip +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { - lazy val names = nameExprs.map(_.eval(EmptyRow)) + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } - override def nullable: Boolean = false + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - override def foldable: Boolean = valExprs.forall(_.foldable) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, expr) => - val metadata = expr match { - case ne: NamedExpression => ne.metadata - case _ => Metadata.empty - } - StructField(name.toString, expr.dataType, expr.nullable, metadata) + case (name, valExpr: NamedExpression) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -245,8 +274,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable StringType expressions are allowed to appear at odd position , got :" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -255,29 +284,9 @@ trait CreateNamedStructLike extends Expression { } } - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } - override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } -} - -/** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -307,6 +316,44 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def prettyName: String = "named_struct" } +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. 
+ */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = GenerateUnsafeProjection.createCode(ctx, children) + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) + } + + override def prettyName: String = "struct_unsafe" +} + + /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -314,7 +361,31 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 35aca91cf882..38e9bb6c162a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -681,8 +681,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case struct: CreateNamedStruct => struct.valExprs // style 1 - case child => Seq(child) // style 2 + case CreateStruct(children) => children // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 817de48de279..590774c04304 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.ShouldMatchers - import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -27,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ - -class AnalysisSuite extends AnalysisTest with ShouldMatchers { +class AnalysisSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { @@ -221,36 +218,9 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { // CreateStruct is a special case that we should not trim Alias for it. plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) - expected = testRelation.select(CreateNamedStruct(Seq( - Literal(a.name), a, - Literal("a+1"), (a + 1))).as("col")) - checkAnalysis(plan, expected) - } - - test("Analysis may leave unnecassary aliases") { - val att1 = testRelation.output.head - var plan = testRelation.select( - CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), - att1 - ) - val prevPlan = getAnalyzer(true).execute(plan) - plan = prevPlan.select(CreateArray(Seq( - CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"), - /** alias should be eliminated by [[CleanupAliases]] */ - "col".attr.as("col2") - )).as("arr")) - plan = getAnalyzer(true).execute(plan) - - val expectedPlan = prevPlan.select( - CreateArray(Seq( - CreateNamedStruct(Seq( - Literal(att1.name), att1, - Literal("a_plus_1"), (att1 + 1))), - 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull - )).as("arr") - ) - - checkAnalysis(plan, expectedPlan) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index c21c6de32c0b..0c307b2b8576 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,6 +243,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) + checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 067b0bac6303..05e867bf5be9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -183,9 +183,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - // Wait until the struct 
is resolved. This will generate a nicer looking alias. - case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4aff0..f873f34a845e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + private def getStruct(exprs: Seq[Expression]): CreateStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index d496af686d75..6741703d9d82 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -14,4 +14,4 @@ select 'foo' from myview where int_col == 0 group by 1; select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; -- group-by should not produce any rows (sort aggregate). 
-select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1; +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index dede3a09ce75..9127bd4dd4c6 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -44,8 +44,8 @@ struct -- !query 5 -select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1 +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 -- !query 5 schema -struct> +struct> -- !query 5 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 90000445dffb..6eb571b91ffa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,12 +190,6 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } - private def quoteHiveFile(path : String) = if (Utils.isWindows) { - getHiveFile(path).getPath.replace('\\', '/') - } else { - getHiveFile(path).getPath - } - def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -231,16 +225,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -250,7 +244,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -275,7 +269,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -314,7 +308,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -385,7 +379,7 @@ private[hive] 
class TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index cdda29af50e3..de0116a4dcba 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 12d18dc87ceb..c7f10e569fa4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} -import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -110,15 +109,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) - if (resourceStream == null) { + val resourceFile = getClass.getClassLoader.getResource(goldenFileName) + if (resourceFile == null) { throw new NoSuchFileException(goldenFileName) } - val 
answerText = try { - Source.fromInputStream(resourceStream).mkString - } finally { - resourceStream.close - } + val path = resourceFile.getPath + val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() From 8ac09108fcf3fb62a812333a5b386b566a9d98ec Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 1 Nov 2016 10:46:36 -0700 Subject: [PATCH 074/381] [SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit ## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15414 from zhengruifeng/move_cast. --- .../scala/org/apache/spark/ml/Predictor.scala | 12 ++- .../spark/ml/classification/Classifier.scala | 4 +- .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../org/apache/spark/ml/PredictorSuite.scala | 82 +++++++++++++++++++ .../LogisticRegressionSuite.scala | 1 - 9 files changed, 98 insertions(+), 11 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6..aa92edde7acd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f234..a3da3067e1b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. 
* * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). * @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda032..f8f164e8c14b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42e..c4651054fd76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99d..b03a07a6bc1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f6..8656ecf609ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82d..ae876b383973 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 000000000000..03e0c536a973 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d314..8771fd2e9d2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1807,7 +1807,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") { From 8cdf143f4b1ca5c6bc0256808e6f42d9ef299cbd Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 1 Nov 2016 11:17:35 -0700 Subject: [PATCH 075/381] [SPARK-18103][FOLLOW-UP][SQL][MINOR] Rename `MetadataLogFileCatalog` to `MetadataLogFileIndex` ## What changes were proposed in this pull request? This is a follow-up to https://github.com/apache/spark/pull/15634. ## How was this patch tested? N/A Author: Liwei Lin Closes #15712 from lw-lin/18103. 
--- .../{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala From 8a538c97b556f80f67c80519af0ce879557050d5 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Tue, 1 Nov 2016 11:18:42 -0700 Subject: [PATCH 076/381] [SPARK-18189][SQL] Fix serialization issue in KeyValueGroupedDataset ## What changes were proposed in this pull request? Likewise [DataSet.scala](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L156) KeyValueGroupedDataset should mark the queryExecution as transient. As mentioned in the Jira ticket, without transient we saw serialization issues like ``` Caused by: java.io.NotSerializableException: org.apache.spark.sql.execution.QueryExecution Serialization stack: - object not serializable (class: org.apache.spark.sql.execution.QueryExecution, value: == ``` ## How was this patch tested? Run the query which is specified in the Jira ticket before and after: ``` val a = spark.createDataFrame(sc.parallelize(Seq((1,2),(3,4)))).as[(Int,Int)] val grouped = a.groupByKey( {x:(Int,Int)=>x._1} ) val mappedGroups = grouped.mapGroups((k,x)=> {(k,1)} ) val yyy = sc.broadcast(1) val last = mappedGroups.rdd.map(xx=> { val simpley = yyy.value 1 } ) ``` Author: Ergin Seyfe Closes #15706 from seyfe/keyvaluegrouped_serialization. 
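For readers unfamiliar with the mechanism, here is a minimal, Spark-independent sketch of why `@transient` resolves this: a `@transient` field is skipped by Java serialization, so serializing the enclosing object (for example, when it is captured by a closure) no longer drags the non-serializable member along. The class names below are illustrative only, not Spark APIs.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Stand-in for a non-serializable member such as QueryExecution.
class HeavyNonSerializableState {
  def describe(): String = "analyzer/planner state that should never be shipped"
}

// Marking the field @transient excludes it from Java serialization, so the
// writeObject call below succeeds even though HeavyNonSerializableState
// itself does not implement Serializable.
class Wrapper(@transient val state: HeavyNonSerializableState, val id: Int) extends Serializable

object TransientSketch {
  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    new ObjectOutputStream(out).writeObject(new Wrapper(new HeavyNonSerializableState, 42))
    println(s"serialized ${out.size()} bytes")  // works; `state` was simply dropped
  }
}
```

Without the `@transient` annotation, the `writeObject` call above fails with the same `java.io.NotSerializableException` shown in the ticket.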
--- .../scala/org/apache/spark/repl/ReplSuite.scala | 17 +++++++++++++++++ .../spark/sql/KeyValueGroupedDataset.scala | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 9262e938c2a6..96d2dfc2658b 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -473,4 +473,21 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("AssertionError", output) assertDoesNotContain("Exception", output) } + + test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") { + val resultValue = 12345 + val output = runInterpreter("local", + s""" + |val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1) + |val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1)) + |val broadcasted = sc.broadcast($resultValue) + | + |// Using broadcast triggers serialization issue in KeyValueGroupedDataset + |val dataset = mapGroups.map(_ => broadcasted.value) + |dataset.collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains(s": Array[Int] = Array($resultValue, $resultValue)", output) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 4cb0313aa903..31ce8eb25e80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.expressions.ReduceAggregator class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], - val queryExecution: QueryExecution, + @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { From d0272b436512b71f04313e109d3d21a6e9deefca Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 1 Nov 2016 11:25:11 -0700 Subject: [PATCH 077/381] [SPARK-18148][SQL] Misleading Error Message for Aggregation Without Window/GroupBy ## What changes were proposed in this pull request? Aggregation Without Window/GroupBy expressions will fail in `checkAnalysis`, the error message is a bit misleading, we should generate a more specific error message for this case. For example, ``` spark.read.load("/some-data") .withColumn("date_dt", to_date($"date")) .withColumn("year", year($"date_dt")) .withColumn("week", weekofyear($"date_dt")) .withColumn("user_count", count($"userId")) .withColumn("daily_max_in_week", max($"user_count").over(weeklyWindow)) ) ``` creates the following output: ``` org.apache.spark.sql.AnalysisException: expression '`randomColumn`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; ``` In the error message above, `randomColumn` doesn't appear in the query(acturally it's added by function `withColumn`), so the message is not enough for the user to address the problem. ## How was this patch tested? Manually test Before: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: expression 'tbl.`col`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; ``` After: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: grouping expressions sequence is empty, and 'tbl.`col`' is not an aggregate function. Wrap '(count(col#231L) AS count(col)#239L)' in windowing function(s) or wrap 'tbl.`col`' in first() (or first_value) if you don't care which value you get.;; ``` Also add new test sqls in `group-by.sql`. Author: jiangxingbo Closes #15672 from jiangxb1987/groupBy-empty. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++ .../resources/sql-tests/inputs/group-by.sql | 41 +++++-- .../sql-tests/results/group-by.sql.out | 116 +++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 ------ 4 files changed, 140 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9a7c2a944b58..3455a567b778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,6 +214,18 @@ trait CheckAnalysis extends PredicateHelper { s"appear in the arguments of an aggregate function.") } } + case e: Attribute if groupingExprs.isEmpty => + // Collect all [[AggregateExpressions]]s. + val aggExprs = aggregateExprs.filter(_.collect { + case a: AggregateExpression => a + }.nonEmpty) + failAnalysis( + s"grouping expressions sequence is empty, " + + s"and '${e.sql}' is not an aggregate function. " + + s"Wrap '${aggExprs.map(_.sql).mkString("(", ", ", ")")}' in windowing " + + s"function(s) or wrap '${e.sql}' in first() (or first_value) " + + s"if you don't care which value you get." + ) case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.sql}' is neither present in the group by, " + diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 6741703d9d82..d950ec83d98c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -1,17 +1,34 @@ --- Temporary data. -create temporary view myview as values 128, 256 as v(int_col); +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b); --- group by should produce all input rows, -select int_col, count(*) from myview group by int_col; +-- Aggregate with empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData; +SELECT COUNT(a), COUNT(b) FROM testData; --- group by should produce a single row. -select 'foo', count(*) from myview group by 1; +-- Aggregate with non-empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData GROUP BY a; +SELECT a, COUNT(b) FROM testData GROUP BY b; +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a; --- group-by should not produce any rows (whole stage code generation). -select 'foo' from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals. +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1; --- group-by should not produce any rows (hash aggregate). -select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (whole stage code generation). 
+SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1; --- group-by should not produce any rows (sort aggregate). -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (hash aggregate). +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate grouped by literals (sort aggregate). +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate with complex GroupBy expressions. +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b; +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1; +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1; + +-- Aggregate with nulls. +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9127bd4dd4c6..a91f04e098b1 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,9 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 14 -- !query 0 -create temporary view myview as values 128, 256 as v(int_col) +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b) -- !query 0 schema struct<> -- !query 0 output @@ -11,41 +13,121 @@ struct<> -- !query 1 -select int_col, count(*) from myview group by int_col +SELECT a, COUNT(b) FROM testData -- !query 1 schema -struct +struct<> -- !query 1 output -128 1 -256 1 +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; -- !query 2 -select 'foo', count(*) from myview group by 1 +SELECT COUNT(a), COUNT(b) FROM testData -- !query 2 schema -struct +struct -- !query 2 output -foo 2 +7 7 -- !query 3 -select 'foo' from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY a -- !query 3 schema -struct +struct -- !query 3 output - +1 2 +2 2 +3 2 +NULL 1 -- !query 4 -select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY b -- !query 4 schema -struct +struct<> -- !query 4 output - +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; -- !query 5 -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a -- !query 5 schema -struct> +struct -- !query 5 output +0 1 +2 2 +2 2 +3 2 + + +-- !query 6 +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1 +-- !query 6 schema +struct +-- !query 6 output +foo 7 + + +-- !query 7 +SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1 +-- !query 7 schema +struct +-- !query 7 output + + +-- !query 8 +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 8 schema +struct +-- !query 8 output + + + +-- !query 9 +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 9 schema +struct> +-- !query 9 output + + + +-- !query 10 +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b +-- !query 10 schema +struct<(a + b):int,count(b):bigint> +-- !query 10 output +2 1 +3 2 +4 2 +5 1 +NULL 1 + + +-- !query 11 +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 12 +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 12 schema +struct<((a + 1) + 1):int,count(b):bigint> +-- !query 12 output +3 2 +4 2 +5 2 +NULL 1 + + +-- !query 13 +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData +-- !query 13 schema +struct +-- !query 13 output +-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1a43d0b2205c..9a3d93cf17b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,20 +463,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - test("agg") { - checkAnswer( - sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1, 3), Row(2, 3), Row(3, 3))) - } - - test("aggregates with nulls") { - checkAnswer( - sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + - "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) - ) - } - test("select *") { checkAnswer( sql("SELECT * FROM testData"), @@ -1178,27 +1164,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1)) } - test("throw errors for non-aggregate attributes with aggregation") { - def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - if (isInvalidQuery) { - val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) - assert(e.getMessage contains "group by") - } else { - // Should not throw - sql(query).queryExecution.analyzed - } - } - - checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", isInvalidQuery = false) - - checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") - checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) - - checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") - checkAggregation("SELECT key + 1 + 1, COUNT(*) 
FROM testData GROUP BY key + 1", false) - } - testQuietly( "SPARK-16748: SparkExceptions during planning should not wrapped in TreeNodeException") { intercept[SparkException] { From cfac17ee1cec414663b957228e469869eb7673c1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 1 Nov 2016 12:35:34 -0700 Subject: [PATCH 078/381] [SPARK-18167] Disable flaky SQLQuerySuite test We now know it's a persistent environmental issue that is causing this test to sometimes fail. One hypothesis is that some configuration is leaked from another suite, and depending on suite ordering this can cause this test to fail. I am planning on mining the jenkins logs to try to narrow down which suite could be causing this. For now, disable the test. Author: Eric Liang Closes #15720 from ericl/disable-flaky-test. --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8b916932ff54..b9353b5b5d2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1565,7 +1565,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } - test("SPARK-10562: partition by column with mixed case name") { + ignore("SPARK-10562: partition by column with mixed case name") { def runOnce() { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 01dd0083011741c2bbe5ae1d2a25f2c9a1302b76 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 1 Nov 2016 12:46:41 -0700 Subject: [PATCH 079/381] [SPARK-17764][SQL] Add `to_json` supporting to convert nested struct column to JSON string ## What changes were proposed in this pull request? This PR proposes to add `to_json` function in contrast with `from_json` in Scala, Java and Python. It'd be useful if we can convert a same column from/to json. Also, some datasources do not support nested types. If we are forced to save a dataframe into those data sources, we might be able to work around by this function. The usage is as below: ``` scala val df = Seq(Tuple1(Tuple1(1))).toDF("a") df.select(to_json($"a").as("json")).show() ``` ``` bash +--------+ | json| +--------+ |{"_1":1}| +--------+ ``` ## How was this patch tested? Unit tests in `JsonFunctionsSuite` and `JsonExpressionsSuite`. Author: hyukjinkwon Closes #15354 from HyukjinKwon/SPARK-17764. 
--- python/pyspark/sql/functions.py | 23 +++++++++ python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 +- .../expressions/jsonExpressions.scala | 48 ++++++++++++++++++- .../sql/catalyst}/json/JacksonGenerator.scala | 5 +- .../sql/catalyst/json/JacksonUtils.scala | 26 ++++++++++ .../expressions/JsonExpressionsSuite.scala | 9 ++++ .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 2 +- .../org/apache/spark/sql/functions.scala | 44 ++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 30 +++++++++--- 11 files changed, 177 insertions(+), 16 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonGenerator.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7fa3fd2de7dd..45e3c22bfc6a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1744,6 +1744,29 @@ def from_json(col, schema, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.1) +def to_json(col, options={}): + """ + Converts a column containing a [[StructType]] into a JSON string. Throws an exception, + in the case of an unsupported type. + + :param col: name of column containing the struct + :param options: options to control converting. accepts the same options as the json datasource + + >>> from pyspark.sql import Row + >>> from pyspark.sql.types import * + >>> data = [(1, Row(name='Alice', age=2))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'{"age":2,"name":"Alice"}')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.to_json(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bc786ef95ed0..b0c51b1e9992 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ Loads a JSON file (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per + `_) or an RDD of Strings storing JSON objects (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 559647bbabf6..1c94413e3c45 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -641,7 +641,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, timestampFormat=None): """ Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) and returns a :class`DataFrame`. + `_) and returns a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 65dbd6a4e3f1..244a5a34f359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, StringWriter} +import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException} +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.ParseModes import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -494,3 +495,46 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child: override def inputTypes: Seq[AbstractDataType] = StringType :: Nil } + +/** + * Converts a [[StructType]] to a json output string. + */ +case class StructToJson(options: Map[String, String], child: Expression) + extends Expression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val writer = new CharArrayWriter() + + @transient + lazy val gen = + new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer) + + override def dataType: DataType = StringType + override def children: Seq[Expression] = child :: Nil + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def eval(input: InternalRow): Any = { + gen.write(child.eval(input).asInstanceOf[InternalRow]) + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 5b55b701862b..4b548e0e7f97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -15,15 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.io.Writer import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index c4d9abb2c07e..3b23c6cd2816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} +import org.apache.spark.sql.types._ + object JacksonUtils { /** * Advance the parser until a null or a specific token is found @@ -29,4 +31,28 @@ object JacksonUtils { case x => x != stopOn } } + + /** + * Verify if the schema is supported in JSON parsing. + */ + def verifySchema(schema: StructType): Unit = { + def verifyType(name: String, dataType: DataType): Unit = dataType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType => + + case st: StructType => st.foreach(field => verifyType(field.name, field.dataType)) + + case at: ArrayType => verifyType(name, at.elementType) + + case mt: MapType => verifyType(name, mt.keyType) + + case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + } + + schema.foreach(field => verifyType(field.name, field.dataType)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 84623934d95d..f9db649bc240 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -343,4 +343,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null ) } + + test("to_json") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), schema) + checkEvaluation( + StructToJson(Map.empty, struct), + """{"a":1}""" + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e0a2471e0fb..eb2b20afc37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.json.JacksonGenerator import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ 
import org.apache.spark.sql.catalyst.plans.logical._ @@ -45,7 +46,6 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 5a409c04c929..0e38aefecb67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextOutputWriter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5f1efd22d820..944a476114fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2883,10 +2883,10 @@ object functions { * (Scala-specific) Parses a column containing a JSON string into a [[StructType]] with the * specified schema. Returns `null`, in the case of an unparseable string. * + * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string * @param options options to control how the json is parsed. accepts the same options and the * json data source. - * @param e a string column containing JSON data. * * @group collection_funcs * @since 2.1.0 @@ -2936,6 +2936,48 @@ object functions { def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) + + /** + * (Scala-specific) Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. + * accepts the same options and the json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: Map[String, String]): Column = withExpr { + StructToJson(options, e.expr) + } + + /** + * (Java-specific) Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. 
+ * accepts the same options and the json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: java.util.Map[String, String]): Column = + to_json(e, options.asScala.toMap) + + /** + * Converts a column containing a [[StructType]] into a JSON string with the + * specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column): Column = + to_json(e, Map.empty[String, String]) + /** * Returns length of array or map. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 518d6e92b2ff..59ae889cf3b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.from_json +import org.apache.spark.sql.functions.{from_json, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType} class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -31,7 +31,6 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("alice", "5")) } - val tuples: Seq[(String, String)] = ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: @@ -97,7 +96,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(expr, expected) } - test("json_parser") { + test("from_json") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -106,7 +105,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(1)) :: Nil) } - test("json_parser missing columns") { + test("from_json missing columns") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("b", IntegerType) @@ -115,7 +114,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null)) :: Nil) } - test("json_parser invalid json") { + test("from_json invalid json") { val df = Seq("""{"a" 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -123,4 +122,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { df.select(from_json($"value", schema)), Row(null) :: Nil) } + + test("to_json") { + val df = Seq(Tuple1(Tuple1(1))).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""{"_1":1}""") :: Nil) + } + + test("to_json unsupported type") { + val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") + .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) + val e = intercept[AnalysisException]{ + // Unsupported type throws an exception + df.select(to_json($"c")).collect() + } + assert(e.getMessage.contains( + "Unable to convert column a of type calendarinterval to JSON.")) + } } From 6e6298154aba63831a292117797798131a646869 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Nov 2016 16:23:47 -0700 Subject: [PATCH 080/381] [SPARK-17350][SQL] Disable default use of KryoSerializer in Thrift Server In SPARK-4761 / #3621 (December 2014) we enabled Kryo serialization by default in the Spark Thrift Server. 
However, I don't think that the original rationale for doing this still holds now that most Spark SQL serialization is now performed via encoders and our UnsafeRow format. In addition, the use of Kryo as the default serializer can introduce performance problems because the creation of new KryoSerializer instances is expensive and we haven't performed instance-reuse optimizations in several code paths (including DirectTaskResult deserialization). Given all of this, I propose to revert back to using JavaSerializer as the default serializer in the Thrift Server. /cc liancheng Author: Josh Rosen Closes #14906 from JoshRosen/disable-kryo-in-thriftserver. --- docs/configuration.md | 5 ++--- .../spark/sql/hive/thriftserver/SparkSQLEnv.scala | 10 ---------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 780fc94908d3..0017219e0726 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -767,7 +767,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true (false when using Spark SQL Thrift Server) + true Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -838,8 +838,7 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
JavaSerializer (org.apache.spark.serializer.
- KryoSerializer when using Spark SQL Thrift Server) + org.apache.spark.serializer.
JavaSerializer Class to use for serializing objects that will be sent over the network or need to be cached diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 638911599aad..78a309497ab5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConverters._ - import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} @@ -37,8 +35,6 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (sqlContext == null) { val sparkConf = new SparkConf(loadDefaults = true) - val maybeSerializer = sparkConf.getOption("spark.serializer") - val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of // the default appName [SparkSQLCLIDriver] in cli or beeline. val maybeAppName = sparkConf @@ -47,12 +43,6 @@ private[hive] object SparkSQLEnv extends Logging { sparkConf .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) - .set( - "spark.serializer", - maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) - .set( - "spark.kryo.referenceTracking", - maybeKryoReferenceTracking.getOrElse("false")) val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate() sparkContext = sparkSession.sparkContext From b929537b6eb0f8f34497c3dbceea8045bf5dffdb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Nov 2016 16:49:41 -0700 Subject: [PATCH 081/381] [SPARK-18182] Expose ReplayListenerBus.read() overload which takes string iterator The `ReplayListenerBus.read()` method is used when implementing a custom `ApplicationHistoryProvider`. The current interface only exposes a `read()` method which takes an `InputStream` and performs stream-to-lines conversion itself, but it would also be useful to expose an overloaded method which accepts an iterator of strings, thereby enabling events to be provided from non-`InputStream` sources. Author: Josh Rosen Closes #15698 from JoshRosen/replay-listener-bus-interface. --- .../spark/scheduler/ReplayListenerBus.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 2424586431aa..0bd5a6bc59a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -53,13 +53,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { sourceName: String, maybeTruncated: Boolean = false, eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + val lines = Source.fromInputStream(logData).getLines() + replay(lines, sourceName, maybeTruncated, eventsFilter) + } + /** + * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an + * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations. 
+ */ + def replay( + lines: Iterator[String], + sourceName: String, + maybeTruncated: Boolean, + eventsFilter: ReplayEventsFilter): Unit = { var currentLine: String = null var lineNumber: Int = 0 try { - val lineEntries = Source.fromInputStream(logData) - .getLines() + val lineEntries = lines .zipWithIndex .filter { case (line, _) => eventsFilter(line) } From 91c33a0ca5c8287f710076ed7681e5aa13ca068f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 1 Nov 2016 17:00:00 -0700 Subject: [PATCH 082/381] [SPARK-18088][ML] Various ChiSqSelector cleanups ## What changes were proposed in this pull request? - Renamed kbest to numTopFeatures - Renamed alpha to fpr - Added missing Since annotations - Doc cleanups ## How was this patch tested? Added new standardized unit tests for spark.ml. Improved existing unit test coverage a bit. Author: Joseph K. Bradley Closes #15647 from jkbradley/chisqselector-follow-ups. --- docs/ml-features.md | 12 +- docs/mllib-feature-extraction.md | 15 +- .../spark/ml/feature/ChiSqSelector.scala | 59 ++++---- .../mllib/api/python/PythonMLLibAPI.scala | 4 +- .../spark/mllib/feature/ChiSqSelector.scala | 45 +++--- .../spark/ml/feature/ChiSqSelectorSuite.scala | 135 ++++++++++-------- .../mllib/feature/ChiSqSelectorSuite.scala | 17 +-- python/pyspark/ml/feature.py | 37 ++--- python/pyspark/mllib/feature.py | 58 ++++---- 9 files changed, 197 insertions(+), 185 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 64c6a160239c..352887d3ba6e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1338,14 +1338,14 @@ for more details on the API. `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. **Examples** diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 87e1e027e945..42568c312e70 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -227,22 +227,19 @@ both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements Chi-Squared feature selection. 
It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. The number of features to select can be tuned using a held-out validation set. ### Model Fitting -`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that -the selector will select. - The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index d0385e220e1e..653fa41124f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol { /** - * Number of features that selector will select (ordered by statistic value descending). If the + * Number of features that selector will select, ordered by ascending p-value. If the * number of features is less than numTopFeatures, then this will select all features. - * Only applicable when selectorType = "kbest". + * Only applicable when selectorType = "numTopFeatures". * The default value of numTopFeatures is 50. * * @group param */ + @Since("1.6.0") final val numTopFeatures = new IntParam(this, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value descending. If the" + + "Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features.", ParamValidators.gtEq(1)) setDefault(numTopFeatures -> 50) /** @group getParam */ + @Since("1.6.0") def getNumTopFeatures: Int = $(numTopFeatures) /** * Percentile of features that selector will select, ordered by statistics value descending. * Only applicable when selectorType = "percentile". * Default value is 0.1. 
+ * @group param */ + @Since("2.1.0") final val percentile = new DoubleParam(this, "percentile", - "Percentile of features that selector will select, ordered by statistics value descending.", + "Percentile of features that selector will select, ordered by ascending p-value.", ParamValidators.inRange(0, 1)) setDefault(percentile -> 0.1) /** @group getParam */ + @Since("2.1.0") def getPercentile: Double = $(percentile) /** * The highest p-value for features to be kept. * Only applicable when selectorType = "fpr". * Default value is 0.05. + * @group param */ - final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", + final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) - setDefault(alpha -> 0.05) + setDefault(fpr -> 0.05) /** @group getParam */ - def getAlpha: Double = $(alpha) + def getFpr: Double = $(fpr) /** * The selector type of the ChisqSelector. - * Supported options: "kbest" (default), "percentile" and "fpr". + * Supported options: "numTopFeatures" (default), "percentile", "fpr". + * @group param */ + @Since("2.1.0") final val selectorType = new Param[String](this, "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", - ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) - setDefault(selectorType -> OldChiSqSelector.KBest) + "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "), + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes)) + setDefault(selectorType -> OldChiSqSelector.NumTopFeatures) /** @group getParam */ + @Since("2.1.0") def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. 
*/ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -113,10 +124,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) - /** @group setParam */ - @Since("2.1.0") - def setSelectorType(value: String): this.type = set(selectorType, value) - /** @group setParam */ @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) @@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = set(alpha, value) + def setFpr(value: Double): this.type = set(fpr, value) + + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) /** @group setParam */ @Since("1.6.0") @@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str .setSelectorType($(selectorType)) .setNumTopFeatures($(numTopFeatures)) .setPercentile($(percentile)) - .setAlpha($(alpha)) + .setFpr($(fpr)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) - otherPairs.foreach { case (_, paramName: String) => + val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType)) + otherPairs.foreach { paramName: String => if (isSet(getParam(paramName))) { logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d0a..034e3625e8c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable { selectorType: String, numTopFeatures: Int, percentile: Double, - alpha: Double, + fpr: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { new ChiSqSelector() .setSelectorType(selectorType) .setNumTopFeatures(numTopFeatures) .setPercentile(percentile) - .setAlpha(alpha) + .setFpr(fpr) .fit(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f8276de4f23d..f9156b642785 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { Loader.checkSchema[Data](dataFrame.schema) val features = dataArray.rdd.map { - case Row(feature: Int) => (feature) + case Row(feature: Int) => feature }.collect() new ChiSqSelectorModel(features) @@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. 
- * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 - var alpha: Double = 0.05 - var selectorType = ChiSqSelector.KBest + var fpr: Double = 0.05 + var selectorType = ChiSqSelector.NumTopFeatures /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } @Since("2.1.0") - def setAlpha(value: Double): this.type = { - require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") - alpha = value + def setFpr(value: Double): this.type = { + require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]") + fpr = value this } @Since("2.1.0") def setSelectorType(value: String): this.type = { - require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + require(ChiSqSelector.supportedSelectorTypes.contains(value), s"ChiSqSelector Type: $value was not supported.") selectorType = value this @@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { - case ChiSqSelector.KBest => + case ChiSqSelector.NumTopFeatures => chiSqTestResult .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) @@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult - .filter { case (res, _) => res.pValue < alpha } + .filter { case (res, _) => res.pValue < fpr } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } @@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } -@Since("2.1.0") -object ChiSqSelector { +private[spark] object ChiSqSelector { - /** String name for `kbest` selector type. */ - private[spark] val KBest: String = "kbest" + /** String name for `numTopFeatures` selector type. */ + val NumTopFeatures: String = "numTopFeatures" /** String name for `percentile` selector type. */ - private[spark] val Percentile: String = "percentile" + val Percentile: String = "percentile" /** String name for `fpr` selector type. */ private[spark] val FPR: String = "fpr" - /** Set of selector type and param pairs that ChiSqSelector supports. */ - private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", - Percentile -> "percentile", FPR -> "alpha") - /** Set of selector types that ChiSqSelector supports. 
*/ - private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) + val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6af06d82d671..80970fd74488 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -19,85 +19,72 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test Chi-Square selector") { - import testImplicits._ - val data = Seq( - LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), - LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), - LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), - LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) - ) + @transient var dataset: Dataset[_] = _ - val preFilteredData = Seq( - Vectors.dense(8.0), - Vectors.dense(0.0), - Vectors.dense(0.0), - Vectors.dense(8.0) - ) + override def beforeAll(): Unit = { + super.beforeAll() - val df = sc.parallelize(data.zip(preFilteredData)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") - - val selector = new ChiSqSelector() - .setSelectorType("kbest") - .setNumTopFeatures(1) - .setFeaturesCol("data") - .setLabelCol("label") - .setOutputCol("filtered") - - selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } - - selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + // Toy dataset, including the top feature for a chi-squared test. + // These data are chosen such that each feature's test has a distinct p-value. 
+ /* To verify the results with R, run: + library(stats) + x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0) + x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0) + x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0) + y <- c(0.0, 1.0, 1.0, 2.0, 2.0) + chisq.test(x1,y) + chisq.test(x2,y) + chisq.test(x3,y) + */ + dataset = spark.createDataFrame(Seq( + (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)), + (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)), + (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)), + (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)), + (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0)) + )).toDF("label", "features", "topFeature") + } - val preFilteredData2 = Seq( - Vectors.dense(8.0, 7.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(8.0, 9.0) - ) + test("params") { + ParamsSuite.checkParams(new ChiSqSelector) + val model = new ChiSqSelectorModel("myModel", + new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4))) + ParamsSuite.checkParams(model) + } - val df2 = sc.parallelize(data.zip(preFilteredData2)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") + test("Test Chi-Square selector: numTopFeatures") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) + ChiSqSelectorSuite.testSelector(selector, dataset) + } - selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + test("Test Chi-Square selector: percentile") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelector read/write") { - val t = new ChiSqSelector() - .setFeaturesCol("myFeaturesCol") - .setLabelCol("myLabelCol") - .setOutputCol("myOutputCol") - .setNumTopFeatures(2) - testDefaultReadWrite(t) + test("Test Chi-Square selector: fpr") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelectorModel read/write") { - val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) - val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.selectedFeatures === instance.selectedFeatures) + test("read/write") { + def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = { + assert(model.selectedFeatures === model2.selectedFeatures) + } + val nb = new ChiSqSelector + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext } } } + +object ChiSqSelectorSuite { + + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { + selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + .foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. 
+ * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "selectorType" -> "percentile", + "numTopFeatures" -> 1, + "percentile" -> 0.12, + "outputCol" -> "myOutput" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ac702b4b7c69..77219e500617 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } - test("ChiSqSelector by FPR transform test (sparse & dense vector)") { + test("ChiSqSelector by fpr transform test (sparse & dense vector)") { val labeledDiscreteData = sc.parallelize( Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))), LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) + val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr") + .setFpr(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } test("model load / save") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 94afe82a3647..635cf1304588 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2606,42 +2606,43 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile and fpr.", typeConverter=TypeConverters.toString) numTopFeatures = \ Param(Params._dummy(), "numTopFeatures", - "Number of features that selector will select, ordered by statistics value " + - "descending. If the number of features is < numTopFeatures, then this will select " + + "Number of features that selector will select, ordered by ascending p-value. 
" + + "If the number of features is < numTopFeatures, then this will select " + "all features.", typeConverter=TypeConverters.toInt) percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " + - "will select, ordered by statistics value descending.", + "will select, ordered by ascending p-value.", typeConverter=TypeConverters.toFloat) - alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.", - typeConverter=TypeConverters.toFloat) + fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05): + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05): """ __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ - labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05) + labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05) """ super(ChiSqSelector, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) - self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05) + self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, + fpr=0.05) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, - labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05): + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05): """ setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \ - labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05) + labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05) Sets params for this ChiSqSelector. """ kwargs = self.setParams._input_kwargs @@ -2665,7 +2666,7 @@ def getSelectorType(self): def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. - Only applicable when selectorType = "kbest". + Only applicable when selectorType = "numTopFeatures". """ return self._set(numTopFeatures=value) @@ -2692,19 +2693,19 @@ def getPercentile(self): return self.getOrDefault(self.percentile) @since("2.1.0") - def setAlpha(self, value): + def setFpr(self, value): """ - Sets the value of :py:attr:`alpha`. + Sets the value of :py:attr:`fpr`. Only applicable when selectorType = "fpr". """ - return self._set(alpha=value) + return self._set(fpr=value) @since("2.1.0") - def getAlpha(self): + def getFpr(self): """ - Gets the value of alpha or its default value. + Gets the value of fpr or its default value. """ - return self.getOrDefault(self.alpha) + return self.getOrDefault(self.fpr) def _create_model(self, java_model): return ChiSqSelectorModel(java_model) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 50ef7c7901c2..7eaa2282cb8b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -274,52 +274,48 @@ def transform(self, vector): class ChiSqSelector(object): """ Creates a ChiSquared feature selector. - The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. - `kbest` chooses the `k` top features according to a chi-squared test. + The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. 
+ `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. `percentile` is similar but chooses a fraction of all features instead of a fixed number. - `fpr` chooses all features whose false positive rate meets some threshold. - By default, the selection method is `kbest`, the default number of top features is 50. + `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + positive rate of selection. + By default, the selection method is `numTopFeatures`, with the default number of top features + set to 50. - >>> data = [ + >>> data = sc.parallelize([ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), ... LabeledPoint(1.0, [0.0, 9.0, 8.0]), - ... LabeledPoint(2.0, [8.0, 9.0, 5.0]) - ... ] - >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) + ... LabeledPoint(2.0, [7.0, 9.0, 5.0]), + ... LabeledPoint(2.0, [8.0, 7.0, 3.0]) + ... ]) + >>> model = ChiSqSelector(numTopFeatures=1).fit(data) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {}) - >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([8.0]) - >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( - ... sc.parallelize(data)) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) + >>> model = ChiSqSelector(selectorType="fpr", fpr=0.2).fit(data) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) SparseVector(1, {}) - >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([8.0]) - >>> data = [ - ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), - ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})), - ... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]), - ... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0]) - ... ] - >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data)) - >>> model.transform(DenseVector([1.0,2.0,3.0,4.0])) - DenseVector([4.0]) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) + >>> model = ChiSqSelector(selectorType="percentile", percentile=0.34).fit(data) + >>> model.transform(DenseVector([7.0, 9.0, 5.0])) + DenseVector([7.0]) .. versionadded:: 1.4.0 """ - def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05): + def __init__(self, numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05): self.numTopFeatures = numTopFeatures self.selectorType = selectorType self.percentile = percentile - self.alpha = alpha + self.fpr = fpr @since('2.1.0') def setNumTopFeatures(self, numTopFeatures): """ set numTopFeature for feature selection by number of top features. - Only applicable when selectorType = "kbest". + Only applicable when selectorType = "numTopFeatures". """ self.numTopFeatures = int(numTopFeatures) return self @@ -334,19 +330,19 @@ def setPercentile(self, percentile): return self @since('2.1.0') - def setAlpha(self, alpha): + def setFpr(self, fpr): """ - set alpha [0.0, 1.0] for feature selection by FPR. + set FPR [0.0, 1.0] for feature selection by FPR. Only applicable when selectorType = "fpr". """ - self.alpha = float(alpha) + self.fpr = float(fpr) return self @since('2.1.0') def setSelectorType(self, selectorType): """ set the selector type of the ChisqSelector. - Supported options: "kbest" (default), "percentile" and "fpr". + Supported options: "numTopFeatures" (default), "percentile", "fpr". 
""" self.selectorType = str(selectorType) return self @@ -362,7 +358,7 @@ def fit(self, data): Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, - self.percentile, self.alpha, data) + self.percentile, self.fpr, data) return ChiSqSelectorModel(jmodel) From 77a98162d1ec28247053b8b3ad4af28baa950797 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 18:06:57 -0700 Subject: [PATCH 083/381] [SPARK-18025] Use commit protocol API in structured streaming ## What changes were proposed in this pull request? This patch adds a new commit protocol implementation ManifestFileCommitProtocol that follows the existing streaming flow, and uses it in FileStreamSink to consolidate the write path in structured streaming with the batch mode write path. This deletes a lot of code, and would make it trivial to support other functionalities that are currently available in batch but not in streaming, including all file formats and bucketing. ## How was this patch tested? Should be covered by existing tests. Author: Reynold Xin Closes #15710 from rxin/SPARK-18025. --- .../datasources/FileCommitProtocol.scala | 11 +- .../execution/datasources/FileFormat.scala | 14 -- ...iteOutput.scala => FileFormatWriter.scala} | 20 +- .../InsertIntoHadoopFsRelationCommand.scala | 25 +- .../parquet/ParquetFileFormat.scala | 11 - .../parquet/ParquetOutputWriter.scala | 116 +-------- .../execution/streaming/FileStreamSink.scala | 229 ++---------------- .../ManifestFileCommitProtocol.scala | 114 +++++++++ .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../sql/streaming/FileStreamSinkSuite.scala | 106 +------- 10 files changed, 174 insertions(+), 475 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/{WriteOutput.scala => FileFormatWriter.scala} (97%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala index 1ce9ae4266c1..f5dd5ce22919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala @@ -32,9 +32,9 @@ import org.apache.spark.util.Utils object FileCommitProtocol { - class TaskCommitMessage(obj: Any) extends Serializable + class TaskCommitMessage(val obj: Any) extends Serializable - object EmptyTaskCommitMessage extends TaskCommitMessage(Unit) + object EmptyTaskCommitMessage extends TaskCommitMessage(null) /** * Instantiates a FileCommitProtocol using the given className. @@ -62,8 +62,11 @@ object FileCommitProtocol { /** - * An interface to define how a Spark job commits its outputs. Implementations must be serializable, - * as the committer instance instantiated on the driver will be used for tasks on executors. + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. A committer should not be reused across multiple Spark jobs. 
* * The proper call sequence is: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 9d153cec731a..4f4aaaa5026f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -55,20 +55,6 @@ trait FileFormat { options: Map[String, String], dataSchema: StructType): OutputWriterFactory - /** - * Returns a [[OutputWriterFactory]] for generating output writers that can write data. - * This method is current used only by FileStreamSinkWriter to generate output writers that - * does not use output committers to write data. The OutputWriter generated by the returned - * [[OutputWriterFactory]] must implement the method `newWriter(path)`.. - */ - def buildWriter( - sqlContext: SQLContext, - dataSchema: StructType, - options: Map[String, String]): OutputWriterFactory = { - // TODO: Remove this default implementation when the other formats have been ported - throw new UnsupportedOperationException(s"buildWriter is not supported for $this") - } - /** * Returns whether this format support returning columnar batch or not. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index a07855111b40..bc00a0a749c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -43,8 +43,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter -/** A helper object for writing data out to a location. */ -object WriteOutput extends Logging { +/** A helper object for writing FileFormat data out to a location. */ +object FileFormatWriter extends Logging { /** A shared job description for all the write tasks. 
*/ private class WriteJobDescription( @@ -55,7 +55,6 @@ object WriteOutput extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val isAppend: Boolean, val path: String) extends Serializable { @@ -82,18 +81,18 @@ object WriteOutput extends Logging { sparkSession: SparkSession, plan: LogicalPlan, fileFormat: FileFormat, - outputPath: Path, + committer: FileCommitProtocol, + outputPath: String, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], refreshFunction: (Seq[TablePartitionSpec]) => Unit, - options: Map[String, String], - isAppend: Boolean): Unit = { + options: Map[String, String]): Unit = { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, outputPath) + FileOutputFormat.setOutputPath(job, new Path(outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,16 +110,11 @@ object WriteOutput extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - isAppend = isAppend, - path = outputPath.toString) + path = outputPath) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - val committer = FileCommitProtocol.instantiate( - sparkSession.sessionState.conf.fileCommitProtocolClass, - outputPath.toString, - isAppend) committer.setupJob(job) try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a1221d0ae6d2..230c74a47ba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -84,17 +84,22 @@ case class InsertIntoHadoopFsRelationCommand( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - WriteOutput.write( - sparkSession, - query, - fileFormat, - qualifiedOutputPath, - hadoopConf, - partitionColumns, - bucketSpec, - refreshFunction, - options, + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + outputPath.toString, isAppend) + + FileFormatWriter.write( + sparkSession = sparkSession, + plan = query, + fileFormat = fileFormat, + committer = committer, + outputPath = qualifiedOutputPath.toString, + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + refreshFunction = refreshFunction, + options = options) } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 77c83ba38efe..b8ea7f40c4ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -415,17 +415,6 @@ class 
ParquetFileFormat } } } - - override def buildWriter( - sqlContext: SQLContext, - dataSchema: StructType, - options: Map[String, String]): OutputWriterFactory = { - new ParquetOutputWriterFactory( - sqlContext.conf, - dataSchema, - sqlContext.sessionState.newHadoopConf(), - options) - } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 92d4f27be3fd..5c0f8af17a23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -17,125 +17,13 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetRecordWriter} -import org.apache.parquet.hadoop.codec.CodecConfig -import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration - - -/** - * A factory for generating OutputWriters for writing parquet files. This implemented is different - * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply - * writes the data to the path used to generate the output writer. Callers of this factory - * has to ensure which files are to be considered as committed. - */ -private[parquet] class ParquetOutputWriterFactory( - sqlConf: SQLConf, - dataSchema: StructType, - hadoopConf: Configuration, - options: Map[String, String]) - extends OutputWriterFactory { - - private val serializableConf: SerializableConfiguration = { - val job = Job.getInstance(hadoopConf) - val conf = ContextUtil.getConfiguration(job) - val parquetOptions = new ParquetOptions(options, sqlConf) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushing down filters. - val dataSchemaToWrite = StructType.removeMetadata( - StructType.metadataKeyForOptionalField, - dataSchema).asInstanceOf[StructType] - ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) - - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). 
- conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlConf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlConf.isParquetINT96AsTimestamp.toString) - - conf.set( - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - sqlConf.writeLegacyParquetFormat.toString) - - // Sets compression scheme - conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) - new SerializableConfiguration(conf) - } - - /** - * Returns a [[OutputWriter]] that writes data to the give path without using - * [[OutputCommitter]]. - */ - override def newWriter(path: String): OutputWriter = new OutputWriter { - - // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter - private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) - private val hadoopAttemptContext = new TaskAttemptContextImpl( - serializableConf.value, hadoopTaskAttemptId) - - // Instance of ParquetRecordWriter that does not use OutputCommitter - private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - - override def write(row: Row): Unit = { - throw new UnsupportedOperationException("call writeInternal") - } - - protected[sql] override def writeInternal(row: InternalRow): Unit = { - recordWriter.write(null, row) - } - - override def close(): Unit = recordWriter.close(hadoopAttemptContext) - } - - /** Create a [[ParquetRecordWriter]] that writes the given path without using OutputCommitter */ - private def createNoCommitterRecordWriter( - path: String, - hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { - // Custom ParquetOutputFormat that disable use of committer and writes to the given path - val outputFormat = new ParquetOutputFormat[InternalRow]() { - override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } - override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } - } - outputFormat.getRecordWriter(hadoopAttemptContext) - } - - /** Disable the use of the older API. */ - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - throw new UnsupportedOperationException("this version of newInstance not supported for " + - "ParquetOutputWriterFactory") - } - - override def getFileExtension(context: TaskAttemptContext): String = { - CodecConfig.from(context).getCodec.getExtension + ".parquet" - } -} - +import org.apache.spark.sql.execution.datasources.OutputWriter // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 02c5b857ee7f..daec2b545097 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -17,23 +17,12 @@ package org.apache.spark.sql.execution.streaming -import java.util.UUID - -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkEnv, SparkException, TaskContext, TaskContextImpl} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.UnsafeKVExternalSorter -import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, PartitioningUtils} -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter +import org.apache.spark.sql.execution.datasources.{FileCommitProtocol, FileFormat, FileFormatWriter} object FileStreamSink { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -59,207 +48,41 @@ class FileStreamSink( private val fileLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() - private val fs = basePath.getFileSystem(hadoopConf) override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") } else { - val writer = new FileStreamSinkWriter( - data, fileFormat, path, partitionColumnNames, hadoopConf, options) - val fileStatuses = writer.write() - if (fileLog.add(batchId, fileStatuses)) { - logInfo(s"Committed batch $batchId") - } else { - throw new IllegalStateException(s"Race while writing batch $batchId") + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.streamingFileCommitProtocolClass, path, isAppend = false) + committer match { + case manifestCommitter: ManifestFileCommitProtocol => + manifestCommitter.setupManifestOptions(fileLog, batchId) + case _ => // Do nothing } - } - } - - override def toString: String = s"FileSink[$path]" -} - - -/** - * Writes data given to a [[FileStreamSink]] to the given `basePath` in the given `fileFormat`, - * partitioned by the given `partitionColumnNames`. This writer always appends data to the - * directory if it already has data. 
- */ -class FileStreamSinkWriter( - data: DataFrame, - fileFormat: FileFormat, - basePath: String, - partitionColumnNames: Seq[String], - hadoopConf: Configuration, - options: Map[String, String]) extends Serializable with Logging { - - PartitioningUtils.validatePartitionColumn( - data.schema, partitionColumnNames, data.sqlContext.conf.caseSensitiveAnalysis) - - private val serializableConf = new SerializableConfiguration(hadoopConf) - private val dataSchema = data.schema - private val dataColumns = data.logicalPlan.output - - // Get the actual partition columns as attributes after matching them by name with - // the given columns names. - private val partitionColumns = partitionColumnNames.map { col => - val nameEquality = data.sparkSession.sessionState.conf.resolver - data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $dataSchema") - } - } - - // Columns that are to be written to the files. If there are partitioning columns, then - // those will not be written to the files. - private val writeColumns = { - val partitionSet = AttributeSet(partitionColumns) - dataColumns.filterNot(partitionSet.contains) - } - - // An OutputWriterFactory for generating writers in the executors for writing the files. - private val outputWriterFactory = - fileFormat.buildWriter(data.sqlContext, writeColumns.toStructType, options) - - /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ - private def partitionStringExpression: Seq[Expression] = { - partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = - ScalaUDF( - PartitioningUtils.escapePathName _, - StringType, - Seq(Cast(c, StringType)), - Seq(StringType)) - val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName - } - } - - /** Generate a new output writer from the writer factory */ - private def newOutputWriter(path: Path): OutputWriter = { - val newWriter = outputWriterFactory.newWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - /** Write the dataframe to files. This gets called in the driver by the [[FileStreamSink]]. */ - def write(): Array[SinkFileStatus] = { - data.sqlContext.sparkContext.runJob( - data.queryExecution.toRdd, - (taskContext: TaskContext, iterator: Iterator[InternalRow]) => { - if (partitionColumns.isEmpty) { - Seq(writePartitionToSingleFile(iterator)) - } else { - writePartitionToPartitionedFiles(iterator) + // Get the actual partition columns as attributes after matching them by name with + // the given columns names. + val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => + val nameEquality = data.sparkSession.sessionState.conf.resolver + data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") } - }).flatten - } - - /** - * Writes a RDD partition to a single file without dynamic partitioning. - * This gets called in the executor, and it uses a [[OutputWriter]] to write the data. 
- */ - def writePartitionToSingleFile(iterator: Iterator[InternalRow]): SinkFileStatus = { - var writer: OutputWriter = null - try { - val path = new Path(basePath, UUID.randomUUID.toString) - val fs = path.getFileSystem(serializableConf.value) - writer = newOutputWriter(path) - while (iterator.hasNext) { - writer.writeInternal(iterator.next) - } - writer.close() - writer = null - SinkFileStatus(fs.getFileStatus(path)) - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) - throw new SparkException("Task failed while writing rows.", cause) - } finally { - if (writer != null) { - writer.close() } - } - } - - /** - * Writes a RDD partition to multiple dynamically partitioned files. - * This gets called in the executor. It first sorts the data based on the partitioning columns - * and then writes the data of each key to separate files using [[OutputWriter]]s. - */ - def writePartitionToPartitionedFiles(iterator: Iterator[InternalRow]): Seq[SinkFileStatus] = { - - // Returns the partitioning columns for sorting - val getSortingKey = UnsafeProjection.create(partitionColumns, dataColumns) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(writeColumns, dataColumns) - - // Returns the partition path given a partition key - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - // Sort the data before write, so that we only need one writer at the same time. - val sorter = new UnsafeKVExternalSorter( - partitionColumns.toStructType, - StructType.fromAttributes(writeColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logDebug(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - val paths = new ArrayBuffer[Path] - - // Write the sorted data to partitioned files, one for each unique key - var currentWriter: OutputWriter = null - try { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = sortedIterator.getKey - - // If key changes, close current writer, and open a new writer to a new partitioned file - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - val partitionPath = getPartitionString(currentKey).getString(0) - val path = new Path(new Path(basePath, partitionPath), UUID.randomUUID.toString) - paths += path - currentWriter = newOutputWriter(path) - logInfo(s"Writing partition $currentKey to $path") - } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - if (paths.nonEmpty) { - val fs = paths.head.getFileSystem(serializableConf.value) - paths.map(p => SinkFileStatus(fs.getFileStatus(p))) - } else Seq.empty - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. 
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) - throw new SparkException("Task failed while writing rows.", cause) - } finally { - if (currentWriter != null) { - currentWriter.close() - } + FileFormatWriter.write( + sparkSession = sparkSession, + plan = data.logicalPlan, + fileFormat = fileFormat, + committer = committer, + outputPath = path, + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = None, + refreshFunction = _ => (), + options = options) } } + + override def toString: String = s"FileSink[$path]" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala new file mode 100644 index 000000000000..510312267a98 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.FileCommitProtocol +import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage + +/** + * A [[FileCommitProtocol]] that tracks the list of valid files in a manifest file, used in + * structured streaming. + * + * @param path path to write the final output to. + */ +class ManifestFileCommitProtocol(path: String) + extends FileCommitProtocol with Serializable with Logging { + + // Track the list of files added by a task, only used on the executors. + @transient private var addedFiles: ArrayBuffer[String] = _ + + @transient private var fileLog: FileStreamSinkLog = _ + private var batchId: Long = _ + + /** + * Sets up the manifest log output and the batch id for this job. + * Must be called before any other function. 
+ */ + def setupManifestOptions(fileLog: FileStreamSinkLog, batchId: Long): Unit = { + this.fileLog = fileLog + this.batchId = batchId + } + + override def setupJob(jobContext: JobContext): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + // Do nothing + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + val fileStatuses = taskCommits.flatMap(_.obj.asInstanceOf[Seq[SinkFileStatus]]).toArray + + if (fileLog.add(batchId, fileStatuses)) { + logInfo(s"Committed batch $batchId") + } else { + throw new IllegalStateException(s"Race while writing batch $batchId") + } + } + + override def abortJob(jobContext: JobContext): Unit = { + require(fileLog != null, "setupManifestOptions must be called before this function") + // Do nothing + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + addedFiles = new ArrayBuffer[String] + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val uuid = UUID.randomUUID.toString + val filename = f"part-$split%05d-$uuid$ext" + + val file = dir.map { d => + new Path(new Path(path, d), filename).toString + }.getOrElse { + new Path(path, filename).toString + } + + addedFiles += file + file + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + if (addedFiles.nonEmpty) { + val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) + val statuses: Seq[SinkFileStatus] = + addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f)))) + new TaskCommitMessage(statuses) + } else { + new TaskCommitMessage(Seq.empty[SinkFileStatus]) + } + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + // Do nothing + // TODO: we can also try delete the addedFiles as a best-effort cleanup. 
+ } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 29e79847aa38..7bb3ac02fa5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.execution.datasources.HadoopCommitProtocolWrapper +import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -523,7 +524,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.streaming.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[HadoopCommitProtocolWrapper].getName) + .createWithDefault(classOf[ManifestFileCommitProtocol].getName) val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 18b42a81a098..902cf0534471 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,106 +17,16 @@ package org.apache.spark.sql.streaming -import java.io.File - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.{DirectoryFileFilter, RegexFileFilter} - import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileIndex} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ - - test("FileStreamSinkWriter - unpartitioned data") { - val path = Utils.createTempDir() - path.delete() - - val hadoopConf = spark.sparkContext.hadoopConfiguration - val fileFormat = new parquet.ParquetFileFormat() - - def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = spark - .range(start, end, 1, numPartitions) - .select($"id", lit(100).as("data")) - val writer = new FileStreamSinkWriter( - df, fileFormat, path.toString, partitionColumnNames = Nil, hadoopConf, Map.empty) - writer.write().map(_.path.stripPrefix("file://")) - } - - // Write and check whether new files are written correctly - val files1 = writeRange(0, 10, 2) - assert(files1.size === 2, s"unexpected number of files: $files1") - checkFilesExist(path, files1, "file not written") - checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) - - // Append and check whether new files are written correctly and old files still exist - val files2 = writeRange(10, 20, 3) - assert(files2.size === 3, s"unexpected number of files: $files2") - assert(files2.intersect(files1).isEmpty, "old files returned") - 
checkFilesExist(path, files2, s"New file not written") - checkFilesExist(path, files1, s"Old file not found") - checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) - } - - test("FileStreamSinkWriter - partitioned data") { - implicit val e = ExpressionEncoder[java.lang.Long] - val path = Utils.createTempDir() - path.delete() - - val hadoopConf = spark.sparkContext.hadoopConfiguration - val fileFormat = new parquet.ParquetFileFormat() - - def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = spark - .range(start, end, 1, numPartitions) - .flatMap(x => Iterator(x, x, x)).toDF("id") - .select($"id", lit(100).as("data1"), lit(1000).as("data2")) - - require(df.rdd.partitions.size === numPartitions) - val writer = new FileStreamSinkWriter( - df, fileFormat, path.toString, partitionColumnNames = Seq("id"), hadoopConf, Map.empty) - writer.write().map(_.path.stripPrefix("file://")) - } - - def checkOneFileWrittenPerKey(keys: Seq[Int], filesWritten: Seq[String]): Unit = { - keys.foreach { id => - assert( - filesWritten.count(_.contains(s"/id=$id/")) == 1, - s"no file for id=$id. all files: \n\t${filesWritten.mkString("\n\t")}" - ) - } - } - - // Write and check whether new files are written correctly - val files1 = writeRange(0, 10, 2) - assert(files1.size === 10, s"unexpected number of files:\n${files1.mkString("\n")}") - checkFilesExist(path, files1, "file not written") - checkOneFileWrittenPerKey(0 until 10, files1) - - val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(spark.read.load(path.getCanonicalPath), answer1) - - // Append and check whether new files are written correctly and old files still exist - val files2 = writeRange(0, 20, 3) - assert(files2.size === 20, s"unexpected number of files:\n${files2.mkString("\n")}") - assert(files2.intersect(files1).isEmpty, "old files returned") - checkFilesExist(path, files2, s"New file not written") - checkFilesExist(path, files1, s"Old file not found") - checkOneFileWrittenPerKey(0 until 20, files2) - - val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2) - } - test("FileStreamSink - unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() @@ -270,18 +180,4 @@ class FileStreamSinkSuite extends StreamTest { } } - private def checkFilesExist(dir: File, expectedFiles: Seq[String], msg: String): Unit = { - import scala.collection.JavaConverters._ - val files = - FileUtils.listFiles(dir, new RegexFileFilter("[^.]+"), DirectoryFileFilter.DIRECTORY) - .asScala - .map(_.getCanonicalPath) - .toSet - - expectedFiles.foreach { f => - assert(files.contains(f), - s"\n$msg\nexpected file:\n\t$f\nfound files:\n${files.mkString("\n\t")}") - } - } - } From ad4832a9faf2c0c869bbcad9d71afe1cecbd3ec8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 21:20:53 -0700 Subject: [PATCH 084/381] [SPARK-18216][SQL] Make Column.expr public ## What changes were proposed in this pull request? Column.expr is private[sql], but it's an actually really useful field to have for debugging. We should open it up, similar to how we use QueryExecution. ## How was this patch tested? N/A - this is a simple visibility change. Author: Reynold Xin Closes #15724 from rxin/SPARK-18216. 
--- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 05e867bf5be9..249408e0fbce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -118,6 +118,9 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * + * Note that the internal Catalyst expression can be accessed via "expr", but this method is for + * debugging purposes only and can change in any future Spark releases. + * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators * @groupname df_ops DataFrame functions @@ -126,7 +129,7 @@ class TypedColumn[-T, U]( * @since 1.3.0 */ @InterfaceStability.Stable -class Column(protected[sql] val expr: Expression) extends Logging { +class Column(val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) From 1ecfafa0869cb3a3e367bda8be252a69874dc4de Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 1 Nov 2016 22:14:53 -0700 Subject: [PATCH 085/381] [SPARK-17838][SPARKR] Check named arguments for options and use formatted R friendly message from JVM exception message ## What changes were proposed in this pull request? This PR proposes to - improve the R-friendly error messages rather than raw JVM exception one. As `read.json`, `read.text`, `read.orc`, `read.parquet` and `read.jdbc` are executed in the same path with `read.df`, and `write.json`, `write.text`, `write.orc`, `write.parquet` and `write.jdbc` shares the same path with `write.df`, it seems it is safe to call `handledCallJMethod` to handle JVM messages. - prevent `zero-length variable name` and prints the ignored options as an warning message. **Before** ``` r > read.json("path", a = 1, 2, 3, "a") Error in env[[name]] <- value : zero-length variable name ``` ``` r > read.json("arbitrary_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...; at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398) ... > read.orc("arbitrary_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...; at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398) ... > read.text("arbitrary_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...; at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398) ... > read.parquet("arbitrary_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: Path does not exist: file:/...; at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$12.apply(DataSource.scala:398) ... ``` ``` r > write.json(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) > write.orc(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) 
: org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) > write.text(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) > write.parquet(df, "existing_path") Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) : org.apache.spark.sql.AnalysisException: path file:/... already exists.; at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:68) ``` **After** ``` r read.json("arbitrary_path", a = 1, 2, 3, "a") Unnamed arguments ignored: 2, 3, a. ``` ``` r > read.json("arbitrary_path") Error in json : analysis error - Path does not exist: file:/... > read.orc("arbitrary_path") Error in orc : analysis error - Path does not exist: file:/... > read.text("arbitrary_path") Error in text : analysis error - Path does not exist: file:/... > read.parquet("arbitrary_path") Error in parquet : analysis error - Path does not exist: file:/... ``` ``` r > write.json(df, "existing_path") Error in json : analysis error - path file:/... already exists.; > write.orc(df, "existing_path") Error in orc : analysis error - path file:/... already exists.; > write.text(df, "existing_path") Error in text : analysis error - path file:/... already exists.; > write.parquet(df, "existing_path") Error in parquet : analysis error - path file:/... already exists.; ``` ## How was this patch tested? Unit tests in `test_utils.R` and `test_sparkSQL.R`. Author: hyukjinkwon Closes #15608 from HyukjinKwon/SPARK-17838. --- R/pkg/R/DataFrame.R | 10 +++--- R/pkg/R/SQLContext.R | 17 ++++----- R/pkg/R/utils.R | 44 ++++++++++++++++------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 16 +++++++++ R/pkg/inst/tests/testthat/test_utils.R | 2 ++ 5 files changed, 64 insertions(+), 25 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1df8bbf9fe60..1cf9b38ea648 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -788,7 +788,7 @@ setMethod("write.json", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "json", path)) + invisible(handledCallJMethod(write, "json", path)) }) #' Save the contents of SparkDataFrame as an ORC file, preserving the schema. @@ -819,7 +819,7 @@ setMethod("write.orc", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "orc", path)) + invisible(handledCallJMethod(write, "orc", path)) }) #' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. @@ -851,7 +851,7 @@ setMethod("write.parquet", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "parquet", path)) + invisible(handledCallJMethod(write, "parquet", path)) }) #' @rdname write.parquet @@ -895,7 +895,7 @@ setMethod("write.text", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) 
- invisible(callJMethod(write, "text", path)) + invisible(handledCallJMethod(write, "text", path)) }) #' Distinct @@ -3342,7 +3342,7 @@ setMethod("write.jdbc", jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) #' randomSplit diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 216ca51666ba..38d83c6e5c52 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -350,7 +350,7 @@ read.json.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "json", paths) + sdf <- handledCallJMethod(read, "json", paths) dataFrame(sdf) } @@ -422,7 +422,7 @@ read.orc <- function(path, ...) { path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "orc", path) + sdf <- handledCallJMethod(read, "orc", path) dataFrame(sdf) } @@ -444,7 +444,7 @@ read.parquet.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "parquet", paths) + sdf <- handledCallJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -496,7 +496,7 @@ read.text.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "text", paths) + sdf <- handledCallJMethod(read, "text", paths) dataFrame(sdf) } @@ -914,12 +914,13 @@ read.jdbc <- function(url, tableName, } else { numPartitions <- numToInt(numPartitions) } - sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), - numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) } else if (length(predicates) > 0) { - sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), + jprops) } else { - sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, jprops) } dataFrame(sdf) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c4e78cbb804d..20004549cc03 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -338,21 +338,41 @@ varargsToEnv <- function(...) { # into string. varargsToStrEnv <- function(...) { pairs <- list(...) + nameList <- names(pairs) env <- new.env() - for (name in names(pairs)) { - value <- pairs[[name]] - if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { - stop(paste0("Unsupported type for ", name, " : ", class(value), - ". Supported types are logical, numeric, character and NULL.")) - } - if (is.logical(value)) { - env[[name]] <- tolower(as.character(value)) - } else if (is.null(value)) { - env[[name]] <- value - } else { - env[[name]] <- as.character(value) + ignoredNames <- list() + + if (is.null(nameList)) { + # When all arguments are not named, names(..) returns NULL. 
+ ignoredNames <- pairs + } else { + for (i in seq_along(pairs)) { + name <- nameList[i] + value <- pairs[i] + if (identical(name, "")) { + # When some of arguments are not named, name is "". + ignoredNames <- append(ignoredNames, value) + } else { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL."), call. = FALSE) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } } } + + if (length(ignoredNames) != 0) { + warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."), + call. = FALSE) + } env } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9289db57b6d6..806019d7524f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2660,6 +2660,14 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), "Error in save : illegal argument - 'path' is not specified") + expect_error(write.json(df, jsonPath), + "Error in json : analysis error - path file:.*already exists") + expect_error(write.text(df, jsonPath), + "Error in text : analysis error - path file:.*already exists") + expect_error(write.orc(df, jsonPath), + "Error in orc : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath), + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2679,6 +2687,11 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") + expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") + expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") + expect_error(read.parquet("arbitrary_path"), + "Error in parquet : analysis error - Path does not exist") # Arguments checking in R side. expect_error(read.df(path = c(3)), @@ -2686,6 +2699,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) + + expect_warning(read.json(jsonPath, a = 1, 2, 3, "a"), + "Unnamed arguments ignored: 2, 3, a.") }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index a20254e9b3fa..607c407f04f9 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -224,6 +224,8 @@ test_that("varargsToStrEnv", { expect_error(varargsToStrEnv(a = list(1, "a")), paste0("Unsupported type for a : list. 
Supported types are logical, ", "numeric, character and NULL.")) + expect_warning(varargsToStrEnv(a = 1, 2, 3, 4), "Unnamed arguments ignored: 2, 3, 4.") + expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) sparkR.session.stop() From 1bbf9ff634745148e782370009aa31d3a042638c Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Tue, 1 Nov 2016 22:20:19 -0700 Subject: [PATCH 086/381] [SPARK-17992][SQL] Return all partitions from HiveShim when Hive throws a metastore exception when attempting to fetch partitions by filter (Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-17992) ## What changes were proposed in this pull request? We recently added table partition pruning for partitioned Hive tables converted to using `TableFileCatalog`. When the Hive configuration option `hive.metastore.try.direct.sql` is set to `false`, Hive will throw an exception for unsupported filter expressions. For example, attempting to filter on an integer partition column will throw a `org.apache.hadoop.hive.metastore.api.MetaException`. I discovered this behavior because VideoAmp uses the CDH version of Hive with a Postgresql metastore DB. In this configuration, CDH sets `hive.metastore.try.direct.sql` to `false` by default, and queries that filter on a non-string partition column will fail. Rather than throw an exception in query planning, this patch catches this exception, logs a warning and returns all table partitions instead. Clients of this method are already expected to handle the possibility that the filters will not be honored. ## How was this patch tested? A unit test was added. Author: Michael Allman Closes #15673 from mallman/spark-17992-catch_hive_partition_filter_exception. --- .../spark/sql/hive/client/HiveShim.scala | 31 ++++++-- .../sql/hive/client/HiveClientBuilder.scala | 56 ++++++++++++++ .../sql/hive/client/HiveClientSuite.scala | 61 +++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 77 +++++-------------- 4 files changed, 160 insertions(+), 65 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 85edaf63db88..3d9642dd1463 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, NoSuchObjectException, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ 
+import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} import org.apache.spark.util.Utils @@ -586,17 +587,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] } else { logDebug(s"Hive metastore filter is '$filter'.") + val tryDirectSqlConfVar = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL + val tryDirectSql = + hive.getConf.getBoolean(tryDirectSqlConfVar.varname, tryDirectSqlConfVar.defaultBoolVal) try { + // Hive may throw an exception when calling this method in some circumstances, such as + // when filtering on a non-string partition column when the hive config key + // hive.metastore.try.direct.sql is false getPartitionsByFilterMethod.invoke(hive, table, filter) .asInstanceOf[JArrayList[Partition]] } catch { - case e: InvocationTargetException => - // SPARK-18167 retry to investigate the flaky test. This should be reverted before - // the release is cut. - val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter)) - logError("getPartitionsByFilter failed, retry success = " + retry.isSuccess) - logError("all partitions: " + getAllPartitions(hive, table)) - throw e + case ex: InvocationTargetException if ex.getCause.isInstanceOf[MetaException] && + !tryDirectSql => + logWarning("Caught Hive MetaException attempting to get partition metadata by " + + "filter from Hive. Falling back to fetching all partition metadata, which will " + + "degrade performance. Modifying your Hive metastore configuration to set " + + s"${tryDirectSqlConfVar.varname} to true may resolve this problem.", ex) + // HiveShim clients are expected to handle a superset of the requested partitions + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] + case ex: InvocationTargetException if ex.getCause.isInstanceOf[MetaException] && + tryDirectSql => + throw new RuntimeException("Caught Hive MetaException attempting to get partition " + + "metadata by filter from Hive. You can set the Spark configuration setting " + + s"${SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key} to false to work around this " + + "problem, however this will result in degraded performance. Please report a bug: " + + "https://issues.apache.org/jira/browse/SPARK", ex) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala new file mode 100644 index 000000000000..591a968c8284 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import java.io.File + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + +private[client] class HiveClientBuilder { + private val sparkConf = new SparkConf() + + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } + + private def buildConf() = { + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() + metastorePath.delete() + Map( + "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", + "hive.metastore.warehouse.dir" -> warehousePath.toString) + } + + def buildClient(version: String, hadoopConf: Configuration): HiveClient = { + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = hadoopConf, + config = buildConf(), + ivyPath = ivyPath).createClient() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala new file mode 100644 index 000000000000..4790331168bd --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.types.IntegerType + +class HiveClientSuite extends SparkFunSuite { + private val clientBuilder = new HiveClientBuilder + + private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname + + test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { + val testPartitionCount = 5 + + val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty) + + val hadoopConf = new Configuration() + hadoopConf.setBoolean(tryDirectSqlKey, false) + val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)") + + val partitions = (1 to testPartitionCount).map { part => + CatalogTablePartition(Map("part" -> part.toString), storageFormat) + } + client.createPartitions( + "default", "test", partitions, ignoreIfExists = false) + + val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), + Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3)))) + + assert(filteredPartitions.size == testPartitionCount) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a10957c8efa..081b0ed9bd68 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,9 +23,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.util.VersionInfo -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -48,46 +47,19 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - private val sparkConf = new SparkConf() - - // In order to speed up test execution during development or in Jenkins, you can specify the path - // of an existing Ivy cache: - private val ivyPath: Option[String] = { - sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( - Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) - } - - private def buildConf() = { - lazy val warehousePath = Utils.createTempDir() - lazy val metastorePath = Utils.createTempDir() - metastorePath.delete() - Map( - "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", - "hive.metastore.warehouse.dir" -> warehousePath.toString) - } + private val clientBuilder = new HiveClientBuilder + import clientBuilder.buildClient test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion( - hiveMetastoreVersion = 
HiveUtils.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = new Configuration(), - config = buildConf(), - ivyPath = ivyPath).createClient() + val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", "loc", Map()) badClient.createDatabase(db, ignoreIfExists = true) } test("hadoop configuration preserved") { - val hadoopConf = new Configuration(); + val hadoopConf = new Configuration() hadoopConf.set("test", "success") - val client = IsolatedClientLoader.forVersion( - hiveMetastoreVersion = HiveUtils.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = hadoopConf, - config = buildConf(), - ivyPath = ivyPath).createClient() + val client = buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) assert("success" === client.getConf("test", null)) } @@ -109,15 +81,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = "13", - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = new Configuration(), - config = buildConf(), - ivyPath = ivyPath).createClient() - } + val badClient = quietly { buildClient("13", new Configuration()) } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } @@ -130,16 +94,9 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - val hadoopConf = new Configuration(); + val hadoopConf = new Configuration() hadoopConf.set("test", "success") - client = - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = version, - hadoopVersion = VersionInfo.getVersion, - sparkConf = sparkConf, - hadoopConf = hadoopConf, - config = buildConf(), - ivyPath = ivyPath).createClient() + client = buildClient(version, hadoopConf) } def table(database: String, tableName: String): CatalogTable = { @@ -287,15 +244,19 @@ class VersionsSuite extends SparkFunSuite with Logging { client.runSqlHive("CREATE TABLE src_part (value INT) PARTITIONED BY (key1 INT, key2 INT)") } + val testPartitionCount = 2 + test(s"$version: createPartitions") { - val partition1 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "1"), storageFormat) - val partition2 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "2"), storageFormat) + val partitions = (1 to testPartitionCount).map { key2 => + CatalogTablePartition(Map("key1" -> "1", "key2" -> key2.toString), storageFormat) + } client.createPartitions( - "default", "src_part", Seq(partition1, partition2), ignoreIfExists = true) + "default", "src_part", partitions, ignoreIfExists = true) } test(s"$version: getPartitions(catalogTable)") { - assert(2 == client.getPartitions(client.getTable("default", "src_part")).size) + assert(testPartitionCount == + client.getPartitions(client.getTable("default", "src_part")).size) } test(s"$version: getPartitionsByFilter") { @@ -306,6 +267,8 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. 
if (version != "0.12") { assert(result.size == 1) + } else { + assert(result.size == testPartitionCount) } } @@ -327,7 +290,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: getPartitions(db: String, table: String)") { - assert(2 == client.getPartitions("default", "src_part", None).size) + assert(testPartitionCount == client.getPartitions("default", "src_part", None).size) } test(s"$version: loadPartition") { From 620da3b4828b3580c7ed7339b2a07938e6be1bb1 Mon Sep 17 00:00:00 2001 From: frreiss Date: Tue, 1 Nov 2016 23:00:17 -0700 Subject: [PATCH 087/381] [SPARK-17475][STREAMING] Delete CRC files if the filesystem doesn't use checksum files ## What changes were proposed in this pull request? When the metadata logs for various parts of Structured Streaming are stored on non-HDFS filesystems such as NFS or ext4, the HDFSMetadataLog class leaves hidden HDFS-style checksum (CRC) files in the log directory, one file per batch. This PR modifies HDFSMetadataLog so that it detects the use of a filesystem that doesn't use CRC files and removes the CRC files. ## How was this patch tested? Modified an existing test case in HDFSMetadataLogSuite to check whether HDFSMetadataLog correctly removes CRC files on the local POSIX filesystem. Ran the entire regression suite. Author: frreiss Closes #15027 from frreiss/fred-17475. --- .../spark/sql/execution/streaming/HDFSMetadataLog.scala | 5 +++++ .../sql/execution/streaming/HDFSMetadataLogSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index c7235320fd6b..9a0f87cf0498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -148,6 +148,11 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) // It will fail if there is an existing file (someone has committed the batch) logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") fileManager.rename(tempPath, batchIdToPath(batchId)) + + // SPARK-17475: HDFSMetadataLog should not leak CRC files + // If the underlying filesystem didn't rename the CRC file, delete it. + val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") + if (fileManager.exists(crcPath)) fileManager.delete(crcPath) return } catch { case e: IOException if isFileAlreadyExistsException(e) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 9c1d26dcb224..d03e08d9a576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -119,6 +119,12 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { assert(metadataLog.get(1).isEmpty) assert(metadataLog.get(2).isDefined) assert(metadataLog.getLatest().get._1 == 2) + + // There should be exactly one file, called "2", in the metadata directory. 
+ // This check also tests for regressions of SPARK-17475 + val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + assert(allFiles.size == 1) + assert(allFiles(0).getName() == "2") } } From abefe2ec428dc24a4112c623fb6fbe4b2ca60a2b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 2 Nov 2016 14:15:10 +0800 Subject: [PATCH 088/381] [SPARK-18183][SPARK-18184] Fix INSERT [INTO|OVERWRITE] TABLE ... PARTITION for Datasource tables ## What changes were proposed in this pull request? There are a couple issues with the current 2.1 behavior when inserting into Datasource tables with partitions managed by Hive. (1) OVERWRITE TABLE ... PARTITION will actually overwrite the entire table instead of just the specified partition. (2) INSERT|OVERWRITE does not work with partitions that have custom locations. This PR fixes both of these issues for Datasource tables managed by Hive. The behavior for legacy tables or when `manageFilesourcePartitions = false` is unchanged. There is one other issue in that INSERT OVERWRITE with dynamic partitions will overwrite the entire table instead of just the updated partitions, but this behavior is pretty complicated to implement for Datasource tables. We should address that in a future release. ## How was this patch tested? Unit tests. Author: Eric Liang Closes #15705 from ericl/sc-4942. --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 9 +++- .../plans/logical/basicLogicalOperators.scala | 19 ++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 15 ++++-- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../datasources/CatalogFileIndex.scala | 5 +- .../datasources/DataSourceStrategy.scala | 30 +++++++++-- .../InsertIntoDataSourceCommand.scala | 6 +-- .../spark/sql/hive/HiveStrategies.scala | 3 +- .../CreateHiveTableAsSelectCommand.scala | 5 +- .../PartitionProviderCompatibilitySuite.scala | 52 +++++++++++++++++++ 11 files changed, 129 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 66e52ca68af1..e901683be685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, OverwriteOptions(overwrite), false) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 38e9bb6c162a..ac1577b3abb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -177,12 +177,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } + val overwrite = ctx.OVERWRITE != null + val overwritePartition = + if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { + Some(partitionKeys.map(t => (t._1, t._2.get))) + } else { + None + } InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, + OverwriteOptions(overwrite, overwritePartition), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a48974c6322a..7a15c2285d58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog.CatalogTypes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -345,18 +346,32 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } +/** + * Options for writing new data into a table. + * + * @param enabled whether to overwrite existing data in the table. + * @param specificPartition only data in the specified partition will be overwritten. + */ +case class OverwriteOptions( + enabled: Boolean, + specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { + if (specificPartition.isDefined) { + assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + } +} + case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean, + overwrite: OverwriteOptions, ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - assert(overwrite || !ifNotExists) + assert(overwrite.enabled || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && table.resolved diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index ca86304d4d40..7400f3430e99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -180,7 +180,16 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable( + table("s"), partition, plan, + OverwriteOptions( + overwrite, + if (overwrite && partition.nonEmpty) { + Some(partition.map(kv => (kv._1, kv._2.get))) + } else { + None + }), + ifNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -196,9 +205,9 @@ class PlanParserSuite extends PlanTest { val plan2 = 
table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), OverwriteOptions(false), ifNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, OverwriteOptions(false), ifNotExists = false))) } test ("insert with if not exists") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 11dd1df90993..700f4835ac89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Union} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions, Union} import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -259,7 +259,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], child = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = OverwriteOptions(mode == SaveMode.Overwrite), ifNotExists = false)).toRdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 092aabc89a36..443a2ec033a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,10 @@ class CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - PartitionPath(p.toRow(partitionSchema), p.storage.locationUri.get) + val path = new Path(p.storage.locationUri.get) + val fs = path.getFileSystem(hadoopConf) + PartitionPath( + p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 34b77cab65de..47c1f9d3fac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import 
org.apache.spark.sql._ @@ -174,14 +176,32 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - if (overwrite && inputPaths.contains(outputPath)) { + val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append + if (overwrite.enabled && inputPaths.contains(outputPath)) { throw new AnalysisException( "Cannot overwrite a path that is also being read from.") } + val overwritingSinglePartition = (overwrite.specificPartition.isDefined && + t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.get.partitionProviderIsHive) + + val effectiveOutputPath = if (overwritingSinglePartition) { + val partition = t.sparkSession.sessionState.catalog.getPartition( + l.catalogTable.get.identifier, overwrite.specificPartition.get) + new Path(partition.storage.locationUri.get) + } else { + outputPath + } + + val effectivePartitionSchema = if (overwritingSinglePartition) { + Nil + } else { + query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + } + def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && + if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.partitionProviderIsHive) { val metastoreUpdater = AlterTableAddPartitionCommand( @@ -194,8 +214,8 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { } val insertCmd = InsertIntoHadoopFsRelationCommand( - outputPath, - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), + effectiveOutputPath, + effectivePartitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fe..2eba1e9986ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -30,7 +30,7 @@ import org.apache.spark.sql.sources.InsertableRelation case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, - overwrite: Boolean) + overwrite: OverwriteOptions) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) @@ -40,7 +40,7 @@ case class InsertIntoDataSourceCommand( val data = Dataset.ofRows(sparkSession, query) // Apply the schema of the existing table to the new data. val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + relation.insert(df, overwrite.enabled) // Invalidate the cache. 
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9d2930948d6b..ce1e3eb1a5bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -46,7 +46,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( table: MetastoreRelation, partition, child, overwrite, ifNotExists) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite, ifNotExists) :: Nil + InsertIntoHiveTable( + table, partition, planLater(child), overwrite.enabled, ifNotExists) :: Nil case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" => val newTableDesc = if (tableDesc.storage.serde.isEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index ef5a5a001fb6..cac43597aef2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -21,7 +21,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, OverwriteOptions} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation @@ -88,7 +88,8 @@ case class CreateHiveTableAsSelectCommand( } else { try { sparkSession.sessionState.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + metastoreRelation, Map(), query, overwrite = OverwriteOptions(true), + ifNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 5f16960fb149..ac435bf6195b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -134,4 +134,56 @@ class PartitionProviderCompatibilitySuite } } } + + test("insert overwrite partition of legacy datasource table overwrites entire table") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 100) + + // Dynamic partitions case + spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) + assert(spark.sql("select * from test").count() == 10) + } + } + } + } + + test("insert overwrite partition of new datasource table overwrites just partition") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir) + sql("msck repair table test") + spark.sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(100)""".stripMargin) + assert(spark.sql("select * from test").count() == 104) + + // Test overwriting a partition that has a custom location + withTempDir { dir2 => + sql( + s"""alter table test partition (partCol=1) + |set location '${dir2.getAbsolutePath}'""".stripMargin) + assert(sql("select * from test").count() == 4) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(30)""".stripMargin) + sql( + """insert overwrite table test + |partition (partCol=1) + |select * from range(20)""".stripMargin) + assert(sql("select * from test").count() == 24) + } + } + } + } + } } From a36653c5b7b2719f8bfddf4ddfc6e1b828ac9af1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Nov 2016 23:37:03 -0700 Subject: [PATCH 089/381] [SPARK-18192] Support all file formats in structured streaming ## What changes were proposed in this pull request? This patch adds support for all file formats in structured streaming sinks. This is actually a very small change thanks to all the previous refactoring done using the new internal commit protocol API. ## How was this patch tested? Updated FileStreamSinkSuite to add test cases for json, text, and parquet. Author: Reynold Xin Closes #15711 from rxin/SPARK-18192. 
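For readers who want the user-facing picture, the following is a minimal sketch (not part of the patch) of what this change enables: a streaming query writing through a non-parquet `FileFormat` sink. The application name and all paths are placeholders, and the text-file source is chosen only for illustration.

```scala
import org.apache.spark.sql.SparkSession

object AnyFormatSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("AnyFormatSinkSketch").getOrCreate()

    // Stream lines from a directory of text files (placeholder path).
    val lines = spark.readStream.text("/tmp/stream-input")

    // Write the stream out as JSON. Before this patch only the parquet
    // FileFormat was matched when DataSource constructed a FileStreamSink.
    val query = lines.writeStream
      .format("json")
      .option("checkpointLocation", "/tmp/stream-checkpoint")
      .start("/tmp/stream-output")

    query.awaitTermination()
  }
}
```

A checkpoint location is still required, as with the parquet sink, so the file sink can track which batches have been committed.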
--- .../execution/datasources/DataSource.scala | 8 +-- .../sql/streaming/FileStreamSinkSuite.scala | 62 +++++++++---------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d980e6a15aab..3f956c427655 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,7 +29,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -37,7 +36,6 @@ import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} @@ -292,7 +290,7 @@ case class DataSource( case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) - case parquet: parquet.ParquetFileFormat => + case fileFormat: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") @@ -301,7 +299,7 @@ case class DataSource( throw new IllegalArgumentException( s"Data source $className does not support $outputMode output mode") } - new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) + new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, options) case _ => throw new UnsupportedOperationException( @@ -516,7 +514,7 @@ case class DataSource( val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( - s"Unable to resolve ${name} given [${plan.output.map(_.name).mkString(", ")}]") + s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") }.asInstanceOf[Attribute] } // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 902cf0534471..0f140f94f630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql._ +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} @@ -142,42 +142,38 @@ class FileStreamSinkSuite extends StreamTest { } } - test("FileStreamSink - 
supported formats") { - def testFormat(format: Option[String]): Unit = { - val inputData = MemoryStream[Int] - val ds = inputData.toDS() + test("FileStreamSink - parquet") { + testFormat(None) // should not throw error as default format parquet when not specified + testFormat(Some("parquet")) + } - val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath - val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + test("FileStreamSink - text") { + testFormat(Some("text")) + } - var query: StreamingQuery = null + test("FileStreamSink - json") { + testFormat(Some("text")) + } - try { - val writer = - ds.map(i => (i, i * 1000)) - .toDF("id", "value") - .writeStream - if (format.nonEmpty) { - writer.format(format.get) - } - query = writer - .option("checkpointLocation", checkpointDir) - .start(outputDir) - } finally { - if (query != null) { - query.stop() - } - } - } + def testFormat(format: Option[String]): Unit = { + val inputData = MemoryStream[Int] + val ds = inputData.toDS() - testFormat(None) // should not throw error as default format parquet when not specified - testFormat(Some("parquet")) - val e = intercept[UnsupportedOperationException] { - testFormat(Some("text")) - } - Seq("text", "not support", "stream").foreach { s => - assert(e.getMessage.contains(s)) + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + + try { + val writer = ds.map(i => (i, i * 1000)).toDF("id", "value").writeStream + if (format.nonEmpty) { + writer.format(format.get) + } + query = writer.option("checkpointLocation", checkpointDir).start(outputDir) + } finally { + if (query != null) { + query.stop() + } } } - } From 85c5424d466f4a5765c825e0e2ab30da97611285 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 1 Nov 2016 23:39:53 -0700 Subject: [PATCH 090/381] [SPARK-18144][SQL] logging StreamingQueryListener$QueryStartedEvent ## What changes were proposed in this pull request? The PR fixes the bug that the QueryStartedEvent is not logged the postToAll() in the original code is actually calling StreamingQueryListenerBus.postToAll() which has no listener at all....we shall post by sparkListenerBus.postToAll(s) and this.postToAll() to trigger local listeners as well as the listeners registered in LiveListenerBus zsxwing ## How was this patch tested? The following snapshot shows that QueryStartedEvent has been logged correctly ![image](https://cloud.githubusercontent.com/assets/678008/19821553/007a7d28-9d2d-11e6-9f13-49851559cdaa.png) Author: CodingCat Closes #15675 from CodingCat/SPARK-18144. 
--- .../streaming/StreamingQueryListenerBus.scala | 10 +++++++++- .../spark/sql/streaming/StreamingQuerySuite.scala | 7 ++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index fc2190d39da4..22e4c6380fcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -41,6 +41,8 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) def post(event: StreamingQueryListener.Event) { event match { case s: QueryStartedEvent => + sparkListenerBus.post(s) + // post to local listeners to trigger callbacks postToAll(s) case _ => sparkListenerBus.post(event) @@ -50,7 +52,13 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => - postToAll(e) + // SPARK-18144: we broadcast QueryStartedEvent to all listeners attached to this bus + // synchronously and the ones attached to LiveListenerBus asynchronously. Therefore, + // we need to ignore QueryStartedEvent if this method is called within SparkListenerBus + // thread + if (!LiveListenerBus.withinListenerThread.value || !e.isInstanceOf[QueryStartedEvent]) { + postToAll(e) + } case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 464c443beb6e..31b7fe0b04da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -290,7 +290,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { // A StreamingQueryListener that gets the query status after the first completed trigger val listener = new StreamingQueryListener { @volatile var firstStatus: StreamingQueryStatus = null - override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { } + @volatile var queryStartedEvent = 0 + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + queryStartedEvent += 1 + } override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { if (firstStatus == null) firstStatus = queryProgress.queryStatus } @@ -303,6 +306,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { q.processAllAvailable() eventually(timeout(streamingTimeout)) { assert(listener.firstStatus != null) + // test if QueryStartedEvent callback is called for only once + assert(listener.queryStartedEvent === 1) } listener.firstStatus } finally { From 2dc048081668665f85623839d5f663b402e42555 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 2 Nov 2016 00:08:30 -0700 Subject: [PATCH 091/381] [SPARK-17532] Add lock debugging info to thread dumps. ## What changes were proposed in this pull request? This adds information to the web UI thread dump page about the JVM locks held by threads and the locks that threads are blocked waiting to acquire. This should help find cases where lock contention is causing Spark applications to run slowly. ## How was this patch tested? Tested by applying this patch and viewing the change in the web UI. 
![thread-lock-info](https://cloud.githubusercontent.com/assets/87915/18493057/6e5da870-79c3-11e6-8c20-f54c18a37544.png) Additions: - A "Thread Locking" column with the locks held by the thread or that are blocking the thread - Links from the a blocked thread to the thread holding the lock - Stack frames show where threads are inside `synchronized` blocks, "holding Monitor(...)" Author: Ryan Blue Closes #15088 from rdblue/SPARK-17532-add-thread-lock-info. --- .../org/apache/spark/ui/static/table.js | 3 +- .../ui/exec/ExecutorThreadDumpPage.scala | 12 +++++++ .../apache/spark/util/ThreadStackTrace.scala | 6 +++- .../scala/org/apache/spark/util/Utils.scala | 34 ++++++++++++++++--- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860e..0315ebf5c48a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index a0ef80d9bdae..c6a07445f2a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } }.map { thread => val threadId = thread.threadId + val blockedBy = thread.blockedByThreadId match { + case Some(blockedByThreadId) => + + case None => Text("") + } + val heldLocks = thread.holdingLocks.mkString(", ") + {threadId} {thread.threadName} {thread.threadState} + {blockedBy}{heldLocks} {thread.stackTrace} } @@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage Thread ID Thread Name Thread State + Thread Locks {dumpRows} diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966..b1217980faf1 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee..22c28fba2087 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -2096,15 +2096,41 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) + ++ threadInfo.getLockedMonitors.map(_.lockString) + ).toSet + + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, + stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) } } From bcbe44440e6c871e217f06d2a4696fd41f1d2606 Mon Sep 17 00:00:00 2001 From: Maria Rydzy Date: Wed, 2 Nov 2016 09:09:16 +0000 Subject: [PATCH 092/381] [MINOR] Use <= for clarity in Pi examples' Monte Carlo process ## What changes were proposed in this pull request? If my understanding is correct we should be rather looking at closed disk than the opened one. ## How was this patch tested? Run simple comparison, of the mean squared error of approaches with closed and opened disk. https://gist.github.com/mrydzy/1cf0e5c316ef9d6fbd91426b91f1969f The closed one performed slightly better, but the tested sample wasn't too big, so I rely mostly on the algorithm understanding. Author: Maria Rydzy Closes #15687 from mrydzy/master. --- .../src/main/java/org/apache/spark/examples/JavaSparkPi.java | 2 +- examples/src/main/python/pi.py | 2 +- examples/src/main/scala/org/apache/spark/examples/LocalPi.scala | 2 +- examples/src/main/scala/org/apache/spark/examples/SparkPi.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 7df145e3117b..89855e81f1f7 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -54,7 +54,7 @@ public static void main(String[] args) throws Exception { public Integer call(Integer integer) { double x = Math.random() * 2 - 1; double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; + return (x * x + y * y <= 1) ? 
1 : 0; } }).reduce(new Function2() { @Override diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index e3f0c4aeef1b..37029b76798f 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -38,7 +38,7 @@ def f(_): x = random() * 2 - 1 y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 + return 1 if x ** 2 + y ** 2 <= 1 else 0 count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 720d92fb9d02..121b768e4198 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -26,7 +26,7 @@ object LocalPi { for (i <- 1 to 100000) { val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 + if (x*x + y*y <= 1) count += 1 } println("Pi is roughly " + 4 * count / 100000.0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 272c1a4fc2f4..a5cacf17a5cc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -34,7 +34,7 @@ object SparkPi { val count = spark.sparkContext.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 + if (x*x + y*y <= 1) 1 else 0 }.reduce(_ + _) println("Pi is roughly " + 4.0 * count / (n - 1)) spark.stop() From 98ede49496d0d7b4724085083d4f24436b92a7bf Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 2 Nov 2016 09:10:34 +0000 Subject: [PATCH 093/381] [SPARK-18198][DOC][STREAMING] Highlight code snippets ## What changes were proposed in this pull request? This patch uses `{% highlight lang %}...{% endhighlight %}` to highlight code snippets in the `Structured Streaming Kafka010 integration doc` and the `Spark Streaming Kafka010 integration doc`. This patch consists of two commits: - the first commit fixes only the leading spaces -- this is large - the second commit adds the highlight instructions -- this is much simpler and easier to review ## How was this patch tested? SKIP_API=1 jekyll build ## Screenshots **Before** ![snip20161101_3](https://cloud.githubusercontent.com/assets/15843379/19894258/47746524-a087-11e6-9a2a-7bff2d428d44.png) **After** ![snip20161101_1](https://cloud.githubusercontent.com/assets/15843379/19894324/8bebcd1e-a087-11e6-835b-88c4d2979cfa.png) Author: Liwei Lin Closes #15715 from lw-lin/doc-highlight-code-snippet. --- docs/streaming-kafka-0-10-integration.md | 391 +++++++++--------- .../structured-streaming-kafka-integration.md | 156 +++---- 2 files changed, 287 insertions(+), 260 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index c1ef396907db..b645d3c3a4b5 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -17,69 +17,72 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea
- import org.apache.kafka.clients.consumer.ConsumerRecord - import org.apache.kafka.common.serialization.StringDeserializer - import org.apache.spark.streaming.kafka010._ - import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent - import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe - - val kafkaParams = Map[String, Object]( - "bootstrap.servers" -> "localhost:9092,anotherhost:9092", - "key.deserializer" -> classOf[StringDeserializer], - "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "use_a_separate_group_id_for_each_stream", - "auto.offset.reset" -> "latest", - "enable.auto.commit" -> (false: java.lang.Boolean) - ) - - val topics = Array("topicA", "topicB") - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Subscribe[String, String](topics, kafkaParams) - ) - - stream.map(record => (record.key, record.value)) - +{% highlight scala %} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.spark.streaming.kafka010._ +import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent +import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + +val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) +) + +val topics = Array("topicA", "topicB") +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) +) + +stream.map(record => (record.key, record.value)) +{% endhighlight %} Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
- import java.util.*;
- import org.apache.spark.SparkConf;
- import org.apache.spark.TaskContext;
- import org.apache.spark.api.java.*;
- import org.apache.spark.api.java.function.*;
- import org.apache.spark.streaming.api.java.*;
- import org.apache.spark.streaming.kafka010.*;
- import org.apache.kafka.clients.consumer.ConsumerRecord;
- import org.apache.kafka.common.TopicPartition;
- import org.apache.kafka.common.serialization.StringDeserializer;
- import scala.Tuple2;
-
- Map<String, Object> kafkaParams = new HashMap<>();
- kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092");
- kafkaParams.put("key.deserializer", StringDeserializer.class);
- kafkaParams.put("value.deserializer", StringDeserializer.class);
- kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream");
- kafkaParams.put("auto.offset.reset", "latest");
- kafkaParams.put("enable.auto.commit", false);
-
- Collection<String> topics = Arrays.asList("topicA", "topicB");
-
- final JavaInputDStream<ConsumerRecord<String, String>> stream =
-   KafkaUtils.createDirectStream(
-     streamingContext,
-     LocationStrategies.PreferConsistent(),
-     ConsumerStrategies.Subscribe(topics, kafkaParams)
-   );
-
- stream.mapToPair(
-   new PairFunction<ConsumerRecord<String, String>, String, String>() {
-     @Override
-     public Tuple2<String, String> call(ConsumerRecord<String, String> record) {
-       return new Tuple2<>(record.key(), record.value());
-     }
-   })
+{% highlight java %}
+import java.util.*;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.streaming.api.java.*;
+import org.apache.spark.streaming.kafka010.*;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.StringDeserializer;
+import scala.Tuple2;
+
+Map<String, Object> kafkaParams = new HashMap<>();
+kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092");
+kafkaParams.put("key.deserializer", StringDeserializer.class);
+kafkaParams.put("value.deserializer", StringDeserializer.class);
+kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream");
+kafkaParams.put("auto.offset.reset", "latest");
+kafkaParams.put("enable.auto.commit", false);
+
+Collection<String> topics = Arrays.asList("topicA", "topicB");
+
+final JavaInputDStream<ConsumerRecord<String, String>> stream =
+  KafkaUtils.createDirectStream(
+    streamingContext,
+    LocationStrategies.PreferConsistent(),
+    ConsumerStrategies.Subscribe(topics, kafkaParams)
+  );
+
+stream.mapToPair(
+  new PairFunction<ConsumerRecord<String, String>, String, String>() {
+    @Override
+    public Tuple2<String, String> call(ConsumerRecord<String, String> record) {
+      return new Tuple2<>(record.key(), record.value());
+    }
+  })
+{% endhighlight %}
@@ -109,32 +112,35 @@ If you have a use case that is better suited to batch processing, you can create
- // Import dependencies and create kafka params as in Create Direct Stream above - - val offsetRanges = Array( - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange("test", 0, 0, 100), - OffsetRange("test", 1, 0, 100) - ) +{% highlight scala %} +// Import dependencies and create kafka params as in Create Direct Stream above - val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) +) +val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +{% endhighlight %}
- // Import dependencies and create kafka params as in Create Direct Stream above
-
- OffsetRange[] offsetRanges = {
-   // topic, partition, inclusive starting offset, exclusive ending offset
-   OffsetRange.create("test", 0, 0, 100),
-   OffsetRange.create("test", 1, 0, 100)
- };
-
- JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD(
-   sparkContext,
-   kafkaParams,
-   offsetRanges,
-   LocationStrategies.PreferConsistent()
- );
+{% highlight java %}
+// Import dependencies and create kafka params as in Create Direct Stream above
+
+OffsetRange[] offsetRanges = {
+  // topic, partition, inclusive starting offset, exclusive ending offset
+  OffsetRange.create("test", 0, 0, 100),
+  OffsetRange.create("test", 1, 0, 100)
+};
+
+JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.createRDD(
+  sparkContext,
+  kafkaParams,
+  offsetRanges,
+  LocationStrategies.PreferConsistent()
+);
+{% endhighlight %}
@@ -144,29 +150,33 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd.foreachPartition { iter => - val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - } +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } +} +{% endhighlight %}
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-     rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
-       @Override
-       public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
-         OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
-         System.out.println(
-           o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
-       }
-     });
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+    rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
+      @Override
+      public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
+        OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
+        System.out.println(
+          o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
+      }
+    });
+  }
+});
+{% endhighlight %}
@@ -183,25 +193,28 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - - // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) - } - +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) +} +{% endhighlight %} As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     // some time later, after outputs have completed
-     ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    // some time later, after outputs have completed
+    ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
+  }
+});
+{% endhighlight %}
@@ -210,64 +223,68 @@ For data stores that support transactions, saving offsets in the same transactio
- // The details depend on your data store, but the general idea looks like this +{% highlight scala %} +// The details depend on your data store, but the general idea looks like this - // begin from the the offsets committed to the database - val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => - new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") - }.toMap +// begin from the the offsets committed to the database +val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") +}.toMap - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) - ) +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) +) - stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - val results = yourCalculation(rdd) + val results = yourCalculation(rdd) - // begin your transaction + // begin your transaction - // update results - // update offsets where the end of existing offsets matches the beginning of this batch of offsets - // assert that offsets were updated correctly + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly - // end your transaction - } + // end your transaction +} +{% endhighlight %}
- // The details depend on your data store, but the general idea looks like this
-
- // begin from the the offsets committed to the database
- Map<TopicPartition, Long> fromOffsets = new HashMap<>();
- for (resultSet : selectOffsetsFromYourDatabase)
-   fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
- }
-
- JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
-   streamingContext,
-   LocationStrategies.PreferConsistent(),
-   ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
- );
-
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     Object results = yourCalculation(rdd);
-
-     // begin your transaction
-
-     // update results
-     // update offsets where the end of existing offsets matches the beginning of this batch of offsets
-     // assert that offsets were updated correctly
-
-     // end your transaction
-   }
- });
+{% highlight java %}
+// The details depend on your data store, but the general idea looks like this
+
+// begin from the the offsets committed to the database
+Map<TopicPartition, Long> fromOffsets = new HashMap<>();
+for (resultSet : selectOffsetsFromYourDatabase)
+  fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
+}
+
+JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
+  streamingContext,
+  LocationStrategies.PreferConsistent(),
+  ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
+);
+
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    Object results = yourCalculation(rdd);
+
+    // begin your transaction
+
+    // update results
+    // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+    // assert that offsets were updated correctly
+
+    // end your transaction
+  }
+});
+{% endhighlight %}
@@ -277,25 +294,29 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html
- val kafkaParams = Map[String, Object]( - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - "security.protocol" -> "SSL", - "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", - "ssl.truststore.password" -> "test1234", - "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", - "ssl.keystore.password" -> "test1234", - "ssl.key.password" -> "test1234" - ) +{% highlight scala %} +val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" +) +{% endhighlight %}
- Map<String, Object> kafkaParams = new HashMap<String, Object>();
- // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
- kafkaParams.put("security.protocol", "SSL");
- kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
- kafkaParams.put("ssl.truststore.password", "test1234");
- kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
- kafkaParams.put("ssl.keystore.password", "test1234");
- kafkaParams.put("ssl.key.password", "test1234");
+{% highlight java %}
+Map<String, Object> kafkaParams = new HashMap<String, Object>();
+// the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
+kafkaParams.put("security.protocol", "SSL");
+kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
+kafkaParams.put("ssl.truststore.password", "test1234");
+kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
+kafkaParams.put("ssl.keystore.password", "test1234");
+kafkaParams.put("ssl.key.password", "test1234");
+{% endhighlight %}
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a6c3b3a9024d..c4c9fb3f7d3d 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -19,97 +19,103 @@ application. See the [Deploying](#deploying) subsection below.
+{% highlight scala %} - // Subscribe to 1 topic - val ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to 1 topic +val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to multiple topics - val ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to multiple topics +val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to a pattern - val ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to a pattern +val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] +{% endhighlight %}
+{% highlight java %}
- // Subscribe to 1 topic
- Dataset<Row> ds1 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1")
-   .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to 1 topic
+Dataset<Row> ds1 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load()
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- // Subscribe to multiple topics
- Dataset<Row> ds2 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1,topic2")
-   .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to multiple topics
+Dataset<Row> ds2 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1,topic2")
+  .load()
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- // Subscribe to a pattern
- Dataset<Row> ds3 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribePattern", "topic.*")
-   .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to a pattern
+Dataset<Row> ds3 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribePattern", "topic.*")
+  .load()
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
+{% highlight python %} - # Subscribe to 1 topic - ds1 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to 1 topic +ds1 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - # Subscribe to multiple topics - ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to multiple topics +ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - # Subscribe to a pattern - ds3 = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to a pattern +ds3 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +{% endhighlight %}
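The patch above is verified only by rebuilding the docs with `SKIP_API=1 jekyll build` and eyeballing the rendered pages. As a rough companion check -- not part of the patch; the object name, the `docs/` path and the output format are assumptions -- a short Scala sketch can flag any page where the re-indentation left a `{% highlight %}` tag without a matching `{% endhighlight %}`:

// Rough sketch only: counts {% highlight %} / {% endhighlight %} tags per doc page
// and reports any file where they do not balance. Paths and names are illustrative.
import java.io.File
import scala.io.Source

object HighlightTagCheck {
  def main(args: Array[String]): Unit = {
    val docsDir = new File(if (args.nonEmpty) args(0) else "docs")
    val mdFiles = Option(docsDir.listFiles()).getOrElse(Array.empty[File])
      .filter(_.getName.endsWith(".md"))
    mdFiles.foreach { file =>
      val text = Source.fromFile(file, "UTF-8").mkString
      val opens = """\{%\s*highlight\s+\w+\s*%\}""".r.findAllIn(text).length
      val closes = """\{%\s*endhighlight\s*%\}""".r.findAllIn(text).length
      if (opens != closes) {
        println(s"${file.getName}: $opens highlight tag(s) vs $closes endhighlight tag(s)")
      }
    }
  }
}

A zero-mismatch result does not prove the pages render correctly, but it catches the most likely regression from a large mechanical re-indentation such as this one.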
From 70a5db7bbd192a4bc68bcfdc475ab221adf2fcdd Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Wed, 2 Nov 2016 09:21:26 +0000 Subject: [PATCH 094/381] [SPARK-18204][WEBUI] Remove SparkUI.appUIAddress ## What changes were proposed in this pull request? Removing `appUIAddress` attribute since it is no longer in use. ## How was this patch tested? Local build Author: Jacek Laskowski Closes #15603 from jaceklaskowski/sparkui-fixes. --- .../cluster/StandaloneSchedulerBackend.scala | 6 +++--- .../main/scala/org/apache/spark/ui/SparkUI.scala | 13 +++---------- .../main/scala/org/apache/spark/ui/WebUI.scala | 8 ++++---- .../org/apache/spark/ui/jobs/AllJobsPage.scala | 4 ++-- .../org/apache/spark/ui/UISeleniumSuite.scala | 16 ++++++++-------- .../test/scala/org/apache/spark/ui/UISuite.scala | 13 ++++++------- .../MesosCoarseGrainedSchedulerBackend.scala | 2 +- .../mesos/MesosFineGrainedSchedulerBackend.scala | 2 +- .../apache/spark/streaming/UISeleniumSuite.scala | 12 ++++++------ .../spark/deploy/yarn/ApplicationMaster.scala | 2 +- .../cluster/YarnClientSchedulerBackend.scala | 2 +- 11 files changed, 36 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cf..368cd30a2e11 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class StandaloneSchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val webUrl = sc.ui.map(_.webUrl).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. @@ -103,8 +103,8 @@ private[spark] class StandaloneSchedulerBackend( } else { None } - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) + val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f631a047a707..b828532aba7a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -82,7 +82,7 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + environmentListener.systemProperties.toMap.getOrElse("user.name", "") } def getAppName: String = appName @@ -94,16 +94,9 @@ private[spark] class SparkUI private ( /** Stop the server behind this web interface. Only valid after bind(). 
*/ override def stop() { super.stop() - logInfo("Stopped Spark web UI at %s".format(appUIAddress)) + logInfo(s"Stopped Spark web UI at $webUrl") } - /** - * Return the application UI host:port. This does not include the scheme (http://). - */ - private[spark] def appUIHostPort = publicHostName + ":" + boundPort - - private[spark] def appUIAddress = s"http://$appUIHostPort" - def getSparkUI(appId: String): Option[SparkUI] = { if (appId == this.appId) Some(this) else None } @@ -136,7 +129,7 @@ private[spark] class SparkUI private ( private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) extends WebUITab(parent, prefix) { - def appName: String = parent.getAppName + def appName: String = parent.appName } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a05e0efb7a3e..8c801558672f 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -56,8 +56,8 @@ private[spark] abstract class WebUI( private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath - def getTabs: Seq[WebUITab] = tabs.toSeq - def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getTabs: Seq[WebUITab] = tabs + def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ @@ -133,7 +133,7 @@ private[spark] abstract class WebUI( def initialize(): Unit /** Bind to the HTTP server behind this web interface. */ - def bind() { + def bind(): Unit = { assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") @@ -156,7 +156,7 @@ private[spark] abstract class WebUI( def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) /** Stop the server behind this web interface. Only valid after bind(). 
*/ - def stop() { + def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") serverInfo.get.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 173fc3cf31ce..50e8e2d19e15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -289,8 +289,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val startTime = listener.startTime val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse.toSeq - val failedJobs = listener.failedJobs.reverse.toSeq + val completedJobs = listener.completedJobs.reverse + val failedJobs = listener.failedJobs.reverse val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e5d408a16736..f4786e3931c9 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -473,7 +473,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -486,7 +486,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -620,7 +620,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B test("live UI json application list") { withSpark(newSparkContext()) { sc => val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.appUIAddress + "/api/v1/applications")) + sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) val attempts = (appListJsonAst \ "attempts").children @@ -640,7 +640,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + val stage0 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) @@ -651,7 +651,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("{\n label="groupBy";\n " + "2 [label="MapPartitionsRDD [2]")) - val stage1 = 
Source.fromURL(sc.ui.get.appUIAddress + + val stage1 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) @@ -662,7 +662,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage1.contains("{\n label="groupBy";\n " + "5 [label="MapPartitionsRDD [5]")) - val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + val stage2 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) @@ -687,7 +687,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def goToUi(ui: SparkUI, path: String): Unit = { - go to (ui.appUIAddress.stripSuffix("/") + path) + go to (ui.webUrl.stripSuffix("/") + path) } def parseDate(json: JValue): Long = { @@ -699,6 +699,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 4abcfb7e5191..68c7657cb315 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -66,7 +66,7 @@ class UISuite extends SparkFunSuite { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.get.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -176,19 +176,18 @@ class UISuite extends SparkFunSuite { } } - test("verify appUIAddress contains the scheme") { + test("verify webUrl contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val uiAddress = ui.appUIAddress - val uiHostPort = ui.appUIHostPort - assert(uiAddress.equals("http://" + uiHostPort)) + val uiAddress = ui.webUrl + assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://")) } } - test("verify appUIAddress contains the port") { + test("verify webUrl contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val splitUIAddress = ui.appUIAddress.split(':') + val splitUIAddress = ui.webUrl.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5063c1fe988b..842c05e7bf73 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -158,7 +158,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + 
sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), None, None, sc.conf.getOption("spark.mesos.driver.frameworkId") diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 09a252f3c74a..c1aa00151e69 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), Option.empty, Option.empty, sc.conf.getOption("spark.mesos.driver.frameworkId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 454c3dffa3db..e7cec999c219 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -92,13 +92,13 @@ class UISeleniumSuite val sparkUI = ssc.sparkContext.ui.get eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/")) + go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) } eventually(timeout(10 seconds), interval(50 milliseconds)) { // check whether streaming page exists - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming") val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq h3Text should contain("Streaming Statistics") @@ -180,23 +180,23 @@ class UISeleniumSuite jobDetails should contain("Completed Stages:") // Check a batch page without id - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming/batch/") webDriver.getPageSource should include ("Missing id parameter") // Check a non-exist batch - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/?id=12345") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming/batch/?id=12345") webDriver.getPageSource should include ("does not exist") } ssc.stop(false) eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/")) + go to (sparkUI.webUrl.stripSuffix("/")) find(cssSelector( """ul li a[href*="streaming"]""")) should be(None) } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + go to (sparkUI.webUrl.stripSuffix("/") + "/streaming") val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq h3Text should not contain("Streaming Statistics") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index aabae140af8b..f2b9dfb4d184 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -406,7 +406,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), 
isClusterMode = true) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), + registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl).getOrElse(""), securityMgr) } else { // Sanity check; should never happen in normal operation, since sc should only be null diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d8b36c5feaf5..60da356ad14a 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -44,7 +44,7 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIAddress) } + sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.webUrl) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) From 9c8deef64efee20a0ddc9b612f90e77c80aede60 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 2 Nov 2016 09:39:15 +0000 Subject: [PATCH 095/381] [SPARK-18076][CORE][SQL] Fix default Locale used in DateFormat, NumberFormat to Locale.US ## What changes were proposed in this pull request? Fix `Locale.US` for all usages of `DateFormat`, `NumberFormat` ## How was this patch tested? Existing tests. Author: Sean Owen Closes #15610 from srowen/SPARK-18076. --- .../org/apache/spark/SparkHadoopWriter.scala | 8 +++---- .../apache/spark/deploy/SparkHadoopUtil.scala | 4 ++-- .../apache/spark/deploy/master/Master.scala | 5 ++-- .../apache/spark/deploy/worker/Worker.scala | 4 ++-- .../org/apache/spark/rdd/HadoopRDD.scala | 5 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../apache/spark/rdd/PairRDDFunctions.scala | 4 ++-- .../status/api/v1/JacksonMessageWriter.scala | 4 ++-- .../spark/status/api/v1/SimpleDateParam.scala | 6 ++--- .../scala/org/apache/spark/ui/UIUtils.scala | 3 ++- .../spark/util/logging/RollingPolicy.scala | 6 ++--- .../org/apache/spark/util/UtilsSuite.scala | 2 +- .../deploy/rest/mesos/MesosRestServer.scala | 11 ++++----- .../mllib/pmml/export/PMMLModelExport.scala | 4 ++-- .../expressions/datetimeExpressions.scala | 17 ++++++------- .../expressions/stringExpressions.scala | 2 +- .../spark/sql/catalyst/json/JSONOptions.scala | 6 +++-- .../sql/catalyst/util/DateTimeUtils.scala | 6 ++--- .../expressions/DateExpressionsSuite.scala | 24 +++++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 6 ++--- .../datasources/csv/CSVInferSchema.scala | 4 ++-- .../datasources/csv/CSVOptions.scala | 5 ++-- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../sql/execution/streaming/socket.scala | 4 ++-- .../apache/spark/sql/DateFunctionsSuite.scala | 11 +++++---- .../execution/datasources/csv/CSVSuite.scala | 9 +++---- .../datasources/csv/CSVTypeCastSuite.scala | 9 ++++--- .../hive/execution/InsertIntoHiveTable.scala | 9 +++---- .../spark/sql/hive/hiveWriterContainers.scala | 4 ++-- .../sql/sources/SimpleTextRelation.scala | 3 ++- .../apache/spark/streaming/ui/UIUtils.scala | 8 ++++--- 31 files changed, 103 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc86..7f75a393bf8f 100644 --- 
a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -162,7 +162,7 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private[spark] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac3..23156072c3eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c..4618e6117a4f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a96..8b1c6bf2e5fd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de09..36a2f5c87e37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -243,7 +243,8 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) diff --git 
a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb65887..488e777fea37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad74..67baad1c51bc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap, Locale} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ @@ -1079,7 +1079,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(new Date()) val stageId = self.id val jobConfiguration = job.getConfiguration diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573d..76af33c1a18d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd238222..d8d5e8958b23 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import 
java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62..66b097aa8166 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a..1f263df57c85 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 15ef32f21d90..feacfb7642f2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val hour = minute * 60 def str: (Long) => String = Utils.msDurationToString(_) - val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() + val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator assert(str(123) === "123 ms") assert(str(second) 
=== "1" + sep + "0 s") diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3b96488a129a..ff60b88c6d53 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest.mesos import java.io.File import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse @@ -62,11 +62,10 @@ private[mesos] class MesosSubmitRequestServlet( private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private def newDriverId(submitDate: Date): String = { - "driver-%s-%04d".format( - createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) - } + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + private def newDriverId(submitDate: Date): String = + f"driver-${createDateFormat.format(submitDate)}-${nextDriverNumber.incrementAndGet()}%04d" /** * Build a driver description from the fields specified in the submit request. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 426bb818c926..f5ca1c221d66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.beans.BeanProperty @@ -34,7 +34,7 @@ private[mllib] trait PMMLModelExport { val version = getClass.getPackage.getImplementationVersion val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())) val header = new Header() .setApplication(app) .setTimestamp(timestamp) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7ab68a13e09c..67c078ae5e26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.util.Try @@ -331,7 +331,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val sdf = new SimpleDateFormat(format.toString) + val sdf = new SimpleDateFormat(format.toString, Locale.US) UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } @@ 
-400,7 +400,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) override def eval(input: InternalRow): Any = { val t = left.eval(input) @@ -425,7 +425,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { null } else { val formatString = f.asInstanceOf[UTF8String].toString - Try(new SimpleDateFormat(formatString).parse( + Try(new SimpleDateFormat(formatString, Locale.US).parse( t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) } } @@ -520,7 +520,7 @@ case class FromUnixTime(sec: Expression, format: Expression) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) + Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) override def eval(input: InternalRow): Any = { val time = left.eval(input) @@ -539,9 +539,10 @@ case class FromUnixTime(sec: Expression, format: Expression) if (f == null) { null } else { - Try(UTF8String.fromString(new SimpleDateFormat( - f.asInstanceOf[UTF8String].toString).format(new java.util.Date( - time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + Try( + UTF8String.fromString(new SimpleDateFormat(f.toString, Locale.US). + format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + ).getOrElse(null) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1bcbb6cfc924..25a5e3fd7da7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1415,7 +1415,7 @@ case class Sentences( val locale = if (languageStr != null && countryStr != null) { new Locale(languageStr.toString, countryStr.toString) } else { - Locale.getDefault + Locale.US } getSentences(string.asInstanceOf[UTF8String].toString, locale) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index aec18922ea6c..c45970658cf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.json +import java.util.Locale + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat @@ -56,11 +58,11 @@ private[sql] class JSONOptions( // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 0b643a5b8426..235ca8d2633a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec @@ -79,14 +79,14 @@ object DateTimeUtils { // `SimpleDateFormat` is not thread-safe. val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } } // `SimpleDateFormat` is not thread-safe. private val threadLocalDateFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") + new SimpleDateFormat("yyyy-MM-dd", Locale.US) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 6118a34d29ea..35cea25ba0b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -30,8 +30,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) @@ -49,7 +49,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("DayOfYear") { - val sdfDay = new SimpleDateFormat("D") + val sdfDay = new SimpleDateFormat("D", Locale.US) (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() @@ -411,9 +411,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - 
val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) checkEvaluation( FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) checkEvaluation(FromUnixTime( @@ -430,11 +430,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) @@ -466,11 +466,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("to_unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val date1 = Date.valueOf("2015-07-24") checkEvaluation( ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 4f516d006458..e0a9a0c3d5c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -68,8 +68,8 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(d2.toString === d1.toString) } - val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z", Locale.US) checkFromToJavaDate(new Date(100)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3ab775c90923..1981d8607c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -247,7 +247,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Float.PositiveInfinity case _ => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) } case _: DoubleType => datum match { @@ -256,7 +256,7 @@ private[csv] object CSVTypeCast { case options.positiveInf => Double.PositiveInfinity case _ => Try(datum.toDouble) - 
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) } case _: BooleanType => datum.toBoolean case dt: DecimalType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 014614eb997a..5903729c11fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat @@ -104,11 +105,11 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 0cc1edd196bc..dbc27d8b237f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -102,7 +102,7 @@ object SQLMetrics { */ def stringValue(metricsType: String, values: Seq[Long]): String = { if (metricsType == SUM_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH) + val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index c662e7c6bc77..042977f870b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -21,7 +21,7 @@ import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer @@ -37,7 +37,7 @@ object TextSocketSource { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) - val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f7aa3b747ae5..e05b2252ee34 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ @@ -55,8 +56,8 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) @@ -395,11 +396,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2) + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3) + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") checkAnswer( df.select(from_unixtime(col("a"))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index f7c22c6c93f7..8209b5bd7f9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.nio.charset.UnsupportedCharsetException import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType @@ -487,7 +488,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm", Locale.US) val expected = Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)), Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)), @@ -509,7 +510,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm", Locale.US) val expected = Seq( new Date(dateFormat.parse("26/08/2015 18:00").getTime), new Date(dateFormat.parse("27/10/2014 18:30").getTime), @@ -728,7 +729,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. 
Row(iso8501.format(r.toSeq.head)) @@ -761,7 +762,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("inferSchema", "false") .load(iso8601datesPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd") + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) val expectedDates = dates.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 51832a13cfe0..c74406b9cbfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -144,13 +144,12 @@ class CSVTypeCastSuite extends SparkFunSuite { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } - test("Float and Double Types are cast correctly with Locale") { + test("Float and Double Types are cast without respect to platform default Locale") { val originalLocale = Locale.getDefault try { - val locale : Locale = new Locale("fr", "FR") - Locale.setDefault(locale) - assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) - assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + Locale.setDefault(new Locale("fr", "FR")) + assert(CSVTypeCast.castTo("1,00", FloatType) == 100.0) // Would parse as 1.0 in fr-FR + assert(CSVTypeCast.castTo("1,00", DoubleType) == 100.0) } finally { Locale.setDefault(originalLocale) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 2843100fb3b3..05164d774cca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.IOException import java.net.URI import java.text.SimpleDateFormat -import java.util.{Date, Random} - -import scala.collection.JavaConverters._ +import java.util.{Date, Locale, Random} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -60,9 +58,8 @@ case class InsertIntoHiveTable( private def executionId: String = { val rand: Random = new Random - val format: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS") - val executionId: String = "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) - return executionId + val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US) + "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong) } private def getStagingDir(inputPath: Path, hadoopConf: Configuration): Path = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ea88276bb96c..e53c3e4d4833 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.text.NumberFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.JavaConverters._ @@ -95,7 +95,7 @@ private[hive] class 
SparkHiveWriterContainer( } protected def getOutputName: String = { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 64d0ecbeefc9..cecfd9909865 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.text.NumberFormat +import java.util.Locale import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -141,7 +142,7 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] { - val numberFormat = NumberFormat.getInstance() + val numberFormat = NumberFormat.getInstance(Locale.US) numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 9b1c939e9329..84ecf81abfbf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.ui import java.text.SimpleDateFormat -import java.util.TimeZone +import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit import scala.xml.Node @@ -80,11 +80,13 @@ private[streaming] object UIUtils { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS", Locale.US) } /** From f151bd1af8a05d4b6c901ebe6ac0b51a4a1a20df Mon Sep 17 00:00:00 2001 From: eyal farago Date: Wed, 2 Nov 2016 11:12:20 +0100 Subject: [PATCH 096/381] [SPARK-16839][SQL] Simplify Struct creation code path ## What changes were proposed in this pull request? Simplify struct creation, especially the aspect of `CleanupAliases` which missed some aliases when handling trees created by `CreateStruct`. This PR includes: 1. A failing test (create struct with nested aliases, some of the aliases survive `CleanupAliases`). 2. A fix that transforms `CreateStruct` into a `CreateNamedStruct` constructor, effectively eliminating `CreateStruct` from all expression trees. 3. A `NamePlaceHolder` used by `CreateStruct` when column names cannot be extracted from unresolved `NamedExpression`. 4. A new Analyzer rule that resolves `NamePlaceHolder` into a string literal once the `NamedExpression` is resolved. 5. 
`CleanupAliases` code was simplified as it no longer has to deal with `CreateStruct`'s top level columns. ## How was this patch tested? Running all tests-suits in package org.apache.spark.sql, especially including the analysis suite, making sure added test initially fails, after applying suggested fix rerun the entire analysis package successfully. Modified few tests that expected `CreateStruct` which is now transformed into `CreateNamedStruct`. Author: eyal farago Author: Herman van Hovell Author: eyal farago Author: Eyal Farago Author: Hyukjin Kwon Author: eyalfa Closes #15718 from hvanhovell/SPARK-16839-2. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 53 ++--- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 - .../expressions/complexTypeCreator.scala | 212 ++++++------------ .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++- .../expressions/ComplexTypeSuite.scala | 1 - .../scala/org/apache/spark/sql/Column.scala | 3 + .../command/AnalyzeColumnCommand.scala | 4 +- .../sql-tests/results/group-by.sql.out | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 20 +- .../resources/sqlgen/subquery_in_having_2.sql | 2 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 14 files changed, 169 insertions(+), 198 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524f..d7fe6b32822a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8f4799322b3..5011f2fdbf9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,6 +83,7 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: + 
ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -653,11 +654,12 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateStruct if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) + case c: CreateNamedStruct if containsStar(c.valExprs) => + val newChildren = c.children.grouped(2).flatMap { + case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children + case kv => kv + } + c.copy(children = newChildren.toList ) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1141,7 +1143,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case CreateStruct(exprs) => exprs + case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { - var stop = false e.transformDown { - // CreateStruct is a special case, we need to retain its top level Aliases as they decide the - // name of StructField. We also need to stop transform down this expression, or the Aliases - // under CreateStruct will be mistakenly trimmed. - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } @@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => - var stop = false other transformExpressionsDown { - case c: CreateStruct if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case c: CreateStructUnsafe if !stop => - stop = true - c.copy(children = c.children.map(trimNonTopLevelAliases)) - case Alias(child, _) if !stop => child + case Alias(child, _) => child } } } @@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } + +/** + * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
+ */ +object ResolveCreateNamedStruct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case e: CreateNamedStruct if !e.resolved => + val children = e.children.grouped(2).flatMap { + case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => + Seq(Literal(e.name), e) + case kv => + kv + } + CreateNamedStruct(children.toList) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3e836ca375e2..b028d07fb8d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[CreateStruct]("struct"), + CreateStruct.registryEntry, // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a81fa1ce3adc..03e054d09851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,7 +119,6 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -145,7 +144,6 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 917aa0873130..dbfb2996ec9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -172,101 +174,71 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * Returns a Row containing the evaluation of all children expressions. 
+ * An expression representing a not yet available attribute name. This expression is unevaluable + * and as its name suggests it is a temporary place holder until we're able to determine the + * actual attribute name. */ -@ExpressionDescription( - usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") -case class CreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - +case object NamePlaceholder extends LeafExpression with Unevaluable { + override lazy val resolved: Boolean = false + override def foldable: Boolean = false override def nullable: Boolean = false + override def dataType: DataType = StringType + override def prettyName: String = "NamePlaceholder" + override def toString: String = prettyName +} - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) +/** + * Returns a Row containing the evaluation of all children expressions. + */ +object CreateStruct extends FunctionBuilder { + def apply(children: Seq[Expression]): CreateNamedStruct = { + CreateNamedStruct(children.zipWithIndex.flatMap { + case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) + case (e: NamedExpression, _) => Seq(NamePlaceholder, e) + case (e, index) => Seq(Literal(s"col${index + 1}"), e) + }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") - - ev.copy(code = s""" - boolean ${ev.isNull} = false; - this.$values = new Object[${children.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - children.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + - s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; - """) + /** + * Entry to use in the function registry. + */ + val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { + val info: ExpressionInfo = new ExpressionInfo( + "org.apache.spark.sql.catalyst.expressions.NamedStruct", + null, + "struct", + "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", + "") + ("struct", (info, this)) } - - override def prettyName: String = "struct" } - /** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) + * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +trait CreateNamedStructLike extends Expression { + lazy val (nameExprs, valExprs) = children.grouped(2).map { + case Seq(name, value) => (name, value) + }.toList.unzip - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } + lazy val names = nameExprs.map(_.eval(EmptyRow)) - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + override def nullable: Boolean = false - private lazy val names = nameExprs.map(_.eval(EmptyRow)) + override def foldable: Boolean = valExprs.forall(_.foldable) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + case (name, expr) => + val metadata = expr match { + case ne: NamedExpression => ne.metadata + case _ => Metadata.empty + } + StructField(name.toString, expr.dataType, expr.nullable, metadata) } StructType(fields) } - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -274,8 +246,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable StringType expressions are allowed to appear at odd position , got :" + - s" ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -284,9 +256,29 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } } + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } +} + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -316,44 +308,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "named_struct" } -/** - * Returns a Row containing the evaluation of all children expressions. 
This is a variant that - * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - */ -case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, children) - ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) - } - - override def prettyName: String = "struct_unsafe" -} - - /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -361,31 +315,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { - - private lazy val (nameExprs, valExprs) = - children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) - - override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { - case (name, valExpr: NamedExpression) => - StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) - case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) - } - StructType(fields) - } - - override def foldable: Boolean = valExprs.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(valExprs.map(_.eval(input)): _*) - } - +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ac1577b3abb4..4b151c81d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -688,8 +688,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case CreateStruct(children) => children // style 1 - case child => Seq(child) // style 2 + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 590774c04304..817de48de279 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import org.scalatest.ShouldMatchers + import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -class AnalysisSuite extends AnalysisTest { + +class AnalysisSuite extends AnalysisTest with ShouldMatchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { @@ -218,9 +221,36 @@ class AnalysisSuite extends AnalysisTest { // CreateStruct is a special case that we should not trim Alias for it. plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) - plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) - checkAnalysis(plan, plan) + expected = testRelation.select(CreateNamedStruct(Seq( + Literal(a.name), a, + Literal("a+1"), (a + 1))).as("col")) + checkAnalysis(plan, expected) + } + + test("Analysis may leave unnecassary aliases") { + val att1 = testRelation.output.head + var plan = testRelation.select( + CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), + att1 + ) + val prevPlan = getAnalyzer(true).execute(plan) + plan = prevPlan.select(CreateArray(Seq( + CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"), + /** alias should be eliminated by [[CleanupAliases]] */ + "col".attr.as("col2") + )).as("arr")) + plan = getAnalyzer(true).execute(plan) + + val expectedPlan = prevPlan.select( + CreateArray(Seq( + CreateNamedStruct(Seq( + Literal(att1.name), att1, + Literal("a_plus_1"), (att1 + 1))), + 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull + )).as("arr") + ) + + checkAnalysis(plan, expectedPlan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c307b2b8576..c21c6de32c0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,7 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 249408e0fbce..7a131b30eafd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -186,6 +186,9 @@ class Column(val expr: Expression) extends Logging { case a: AggregateExpression if 
a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) + // Wait until the struct is resolved. This will generate a nicer looking alias. + case struct: CreateNamedStructLike => UnresolvedAlias(struct) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index f873f34a845e..6141fab4aff0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateStruct = { + private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index a91f04e098b1..af6c930d64b7 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -87,7 +87,7 @@ struct -- !query 9 SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 -- !query 9 schema -struct> +struct> -- !query 9 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6eb571b91ffa..90000445dffb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,6 +190,12 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } + private def quoteHiveFile(path : String) = if (Utils.isWindows) { + getHiveFile(path).getPath.replace('\\', '/') + } else { + getHiveFile(path).getPath + } + def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -225,16 +231,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { 
sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -244,7 +250,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -269,7 +275,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -308,7 +314,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -379,7 +385,7 @@ private[hive] class TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index de0116a4dcba..cdda29af50e3 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT 
`gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index c7f10e569fa4..12d18dc87ceb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} +import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -109,12 +110,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceFile = getClass.getClassLoader.getResource(goldenFileName) - if (resourceFile == null) { + val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) + if (resourceStream == null) { throw new NoSuchFileException(goldenFileName) } - val path = resourceFile.getPath - val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) + val answerText = try { + Source.fromInputStream(resourceStream).mkString + } finally { + resourceStream.close + } val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() From 4af0ce2d96de3397c9bc05684cad290a52486577 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 2 Nov 2016 11:29:26 -0700 Subject: [PATCH 097/381] [SPARK-17683][SQL] Support ArrayType in Literal.apply ## What changes were proposed in this pull request? This pr is to add pattern-matching entries for array data in `Literal.apply`. ## How was this patch tested? Added tests in `LiteralExpressionSuite`. Author: Takeshi YAMAMURO Closes #15257 from maropu/SPARK-17683. 
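**Illustrative sketch (not part of the original patch)**: assuming the new `Array[_]` entry in `Literal.apply` and the `componentTypeToDataType` mapping shown in the diff below, usage would look roughly like this. The snippet is an untested sketch; the expected types mirror the element-type mapping and the new `LiteralExpressionSuite` cases in the patch.

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Arrays can now be passed to Literal.apply directly: the element type is
// inferred from the array's component class and the values are converted to
// Catalyst's internal representation.
val intArrayLit = Literal(Array(1, 2, 3))
assert(intArrayLit.dataType == ArrayType(IntegerType))

val strArrayLit = Literal(Array("a", "b", "c"))
assert(strArrayLit.dataType == ArrayType(StringType))

// Maps and tuples are still rejected with "Unsupported literal type ...",
// as exercised by the new negative test in LiteralExpressionSuite.
```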
--- .../sql/catalyst/expressions/literals.scala | 57 ++++++++++++++++++- .../expressions/LiteralExpressionSuite.scala | 27 ++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a597a17aadd9..1985e68c94e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,14 +17,25 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.{Boolean => JavaBoolean} +import java.lang.{Byte => JavaByte} +import java.lang.{Double => JavaDouble} +import java.lang.{Float => JavaFloat} +import java.lang.{Integer => JavaInteger} +import java.lang.{Long => JavaLong} +import java.lang.{Short => JavaShort} +import java.math.{BigDecimal => JavaBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util import java.util.Objects import javax.xml.bind.DatatypeConverter +import scala.math.{BigDecimal, BigInt} + import org.json4s.JsonAST._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -46,12 +57,17 @@ object Literal { case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) - case d: java.math.BigDecimal => + case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case a: Array[_] => + val elementType = componentTypeToDataType(a.getClass.getComponentType()) + val dataType = ArrayType(elementType) + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(a), dataType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) case v: Literal => v @@ -59,6 +75,45 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Returns the Spark SQL DataType for a given class object. Since this type needs to be resolved + * in runtime, we use match-case idioms for class objects here. However, there are similar + * functions in other files (e.g., HiveInspectors), so these functions need to merged into one. 
+ */ + private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match { + // primitive types + case JavaShort.TYPE => ShortType + case JavaInteger.TYPE => IntegerType + case JavaLong.TYPE => LongType + case JavaDouble.TYPE => DoubleType + case JavaByte.TYPE => ByteType + case JavaFloat.TYPE => FloatType + case JavaBoolean.TYPE => BooleanType + + // java classes + case _ if clz == classOf[Date] => DateType + case _ if clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[JavaShort] => ShortType + case _ if clz == classOf[JavaInteger] => IntegerType + case _ if clz == classOf[JavaLong] => LongType + case _ if clz == classOf[JavaDouble] => DoubleType + case _ if clz == classOf[JavaByte] => ByteType + case _ if clz == classOf[JavaFloat] => FloatType + case _ if clz == classOf[JavaBoolean] => BooleanType + + // other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + + case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) + + case _ => throw new AnalysisException(s"Unsupported component type $clz in arrays") + } + /** * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 450222d8cbba..4af4da8a9f0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -43,6 +44,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, TimestampType), null) checkEvaluation(Literal.create(null, CalendarIntervalType), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } @@ -122,5 +124,28 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } } - // TODO(davies): add tests for ArrayType, MapType and StructType + test("array") { + def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { + val toCatalyst = (a: Array[_], elementType: DataType) => { + CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) + } + checkEvaluation(Literal(a), toCatalyst(a, elementType)) + } + checkArrayLiteral(Array(1, 2, 3), IntegerType) + checkArrayLiteral(Array("a", "b", "c"), StringType) + checkArrayLiteral(Array(1.0, 4.0), DoubleType) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, 
CalendarInterval.MICROS_PER_HOUR), + CalendarIntervalType) + } + + test("unsupported types (map and struct) in literals") { + def checkUnsupportedTypeInLiteral(v: Any): Unit = { + val errMsgMap = intercept[RuntimeException] { + Literal(v) + } + assert(errMsgMap.getMessage.startsWith("Unsupported literal type")) + } + checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) + checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) + } } From 742e0fea5391857964e90d396641ecf95cac4248 Mon Sep 17 00:00:00 2001 From: buzhihuojie Date: Wed, 2 Nov 2016 11:36:20 -0700 Subject: [PATCH 098/381] [SPARK-17895] Improve doc for rangeBetween and rowsBetween ## What changes were proposed in this pull request? Copied description for row and range based frame boundary from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala#L56 Added examples to show different behavior of rangeBetween and rowsBetween when involving duplicate values. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: buzhihuojie Closes #15727 from david-weiluo-ren/improveDocForRangeAndRowsBetween. --- .../apache/spark/sql/expressions/Window.scala | 55 +++++++++++++++++++ .../spark/sql/expressions/WindowSpec.scala | 55 +++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 0b26d863cac5..327bc379d413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -121,6 +121,32 @@ object Window { * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 2| + * | 1| a| 3| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the @@ -144,6 +170,35 @@ object Window { * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. 
This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 1e85b6e7881a..4a8ce695bd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -89,6 +89,32 @@ class WindowSpec private[sql]( * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 2| + * | 1| a| 3| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the @@ -111,6 +137,35 @@ class WindowSpec private[sql]( * and [[Window.currentRow]] to specify special boundary values, rather than using integral * values directly. * + * A range based boundary is based on the actual value of the ORDER BY + * expression(s). An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. 
+ * + * {{{ + * import org.apache.spark.sql.expressions.Window + * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) + * .toDF("id", "category") + * df.withColumn("sum", + * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) + * .show() + * + * +---+--------+---+ + * | id|category|sum| + * +---+--------+---+ + * | 1| b| 3| + * | 2| b| 5| + * | 3| b| 3| + * | 1| a| 4| + * | 1| a| 4| + * | 2| a| 2| + * +---+--------+---+ + * }}} + * * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value ([[Window.unboundedPreceding]]). * @param end boundary end, inclusive. The frame is unbounded if this is the From 02f203107b8eda1f1576e36c4f12b0e3bc5e910e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Nov 2016 11:41:49 -0700 Subject: [PATCH 099/381] [SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union ## What changes were proposed in this pull request? When a user appended a column using a "nondeterministic" function to a DataFrame, e.g., `rand`, `randn`, and `monotonically_increasing_id`, the expected semantic is the following: - The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations. However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` after. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column. See the unit tests below or JIRA for examples. This PR uses the partition index from `RDD.mapPartitionWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and initialized right after object creation in `mapPartitionWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization. ## How was this patch tested? Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...) cc: rxin davies Author: Xiangrui Meng Closes #15567 from mengxr/SPARK-14393. 
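As a quick illustration of the guaranteed semantics, here is a minimal sketch (not part of the patch; the local session setup, seed, and partition counts are illustrative only) showing that a `rand` column keeps its values across `coalesce` once nondeterministic expressions are initialized with the correct partition index:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

// Illustrative local session; any existing SparkSession behaves the same way.
val spark = SparkSession.builder().master("local[4]").appName("rand-semantics").getOrCreate()

// Define a nondeterministic column, then reshuffle the partitions.
val df = spark.range(0, 8, 1, 4).withColumn("r", rand(42L))
val before = df.collect().toSet
val afterCoalesce = df.coalesce(2).collect().toSet

// With this change the "r" values are identical; previously they could differ
// because rand() was seeded with the running task's partition id, which changes
// after coalesce.
assert(before == afterCoalesce)
```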
--- .../main/scala/org/apache/spark/rdd/RDD.scala | 16 +++++- .../sql/catalyst/expressions/Expression.scala | 19 +++++-- .../catalyst/expressions/InputFileName.scala | 2 +- .../MonotonicallyIncreasingID.scala | 11 ++-- .../sql/catalyst/expressions/Projection.scala | 22 +++++--- .../expressions/SparkPartitionID.scala | 13 +++-- .../expressions/codegen/CodeGenerator.scala | 14 +++++ .../expressions/codegen/CodegenFallback.scala | 18 +++++-- .../codegen/GenerateMutableProjection.scala | 4 ++ .../codegen/GeneratePredicate.scala | 18 +++++-- .../codegen/GenerateSafeProjection.scala | 4 ++ .../codegen/GenerateUnsafeProjection.scala | 4 ++ .../sql/catalyst/expressions/package.scala | 10 +++- .../sql/catalyst/expressions/predicates.scala | 4 -- .../expressions/randomExpressions.scala | 14 ++--- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../expressions/ExpressionEvalHelper.scala | 5 +- .../CodegenExpressionCachingSuite.scala | 13 +++-- .../sql/execution/DataSourceScanExec.scala | 6 ++- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 3 +- .../spark/sql/execution/SparkPlan.scala | 4 +- .../sql/execution/WholeStageCodegenExec.scala | 8 ++- .../execution/basicPhysicalOperators.scala | 8 +-- .../columnar/InMemoryTableScanExec.scala | 5 +- .../joins/BroadcastNestedLoopJoinExec.scala | 7 +-- .../joins/CartesianProductExec.scala | 8 +-- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 6 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++++++++++++++++++ .../hive/execution/HiveTableScanExec.scala | 3 +- 32 files changed, 231 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb..e018af35cb18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. 
+ */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9edc1ceff26a..726a231fd814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -272,17 +272,28 @@ trait Nondeterministic extends Expression { final override def deterministic: Boolean = false final override def foldable: Boolean = false + @transient private[this] var initialized = false - final def setInitialValues(): Unit = { - initInternal() + /** + * Initializes internal states given the current partition index and mark this as initialized. + * Subclasses should override [[initializeInternal()]]. + */ + final def initialize(partitionIndex: Int): Unit = { + initializeInternal(partitionIndex) initialized = true } - protected def initInternal(): Unit + protected def initializeInternal(partitionIndex: Int): Unit + /** + * @inheritdoc + * Throws an exception if [[initialize()]] is not called yet. + * Subclasses should override [[evalInternal()]]. + */ final override def eval(input: InternalRow = null): Any = { - require(initialized, "nondeterministic expression should be initialized before evaluate") + require(initialized, + s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.") evalInternal(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 96929ecf5637..b6c12c535111 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def prettyName: String = "input_file_name" - override protected def initInternal(): Unit = {} + override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { InputFileNameHolder.getInputFileName() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 5b4922e0cf2b..72b8dcca26e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis @transient private[this] var partitionMask: Long = _ - override protected def initInternal(): Unit = { + override protected def initializeInternal(partitionIndex: Int): Unit = { count = 0L - partitionMask = TaskContext.getPartitionId().toLong << 33 + partitionMask = partitionIndex.toLong << 33 } override def nullable: Boolean = false @@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = 
ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") + ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 03e054d09851..476e37e6a9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ @@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { /** * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified * expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. 
*/ @@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val buffer = new Array[Any](expressions.size) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } private[this] val exprArray = expressions.toArray private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 1f675d5b0727..6bef473cac06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** - * Expression that returns the current partition id of the Spark task. + * Expression that returns the current partition id. */ @ExpressionDescription( - usage = "_FUNC_() - Returns the current partition id of the Spark task", + usage = "_FUNC_() - Returns the current partition id", extended = "> SELECT _FUNC_();\n 0") case class SparkPartitionID() extends LeafExpression with Nondeterministic { @@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override val prettyName = "SPARK_PARTITION_ID" - override protected def initInternal(): Unit = { - partitionId = TaskContext.getPartitionId() + override protected def initializeInternal(partitionIndex: Int): Unit = { + partitionId = partitionIndex } override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, - s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6cab50ae1bf8..9c3c6d3b2a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -184,6 +184,20 @@ class CodegenContext { splitExpressions(initCodes, "init", Nil) } + /** + * Code statements to initialize states that depend on the partition index. + * An integer `partitionIndex` will be made available within the scope. 
+ */ + val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty + + def addPartitionInitializationStatement(statement: String): Unit = { + partitionInitializationStatements += statement + } + + def initPartition(): String = { + partitionInitializationStatements.mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6a5a3e7933ee..0322d1dd6a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No trait CodegenFallback extends Expression { protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } - // LeafNode does not need `input` val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this + var childIndex = idx + this.foreach { + case n: Nondeterministic => + // This might add the current expression twice, but it won't hurt. + ctx.references += n + childIndex += 1 + ctx.addPartitionInitializationStatement( + s""" + |((Nondeterministic) references[$childIndex]) + | .initialize(partitionIndex); + """.stripMargin) + case _ => + } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5c4b56b0b224..4d732445544a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 39aa7b17de6c..dcd1ed96a298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._ */ abstract class Predicate { def eval(r: InternalRow): Boolean + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. 
+ */ + def initialize(partitionIndex: Int): Unit = {} } /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. */ -object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { +object GeneratePredicate extends CodeGenerator[Expression, Predicate] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((InternalRow) => Boolean) = { + protected def create(predicate: Expression): Predicate = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) @@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] - (r: InternalRow) => p.eval(r) + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 2773e1a66621..b1cb6edefb85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public java.lang.Object apply(java.lang.Object _i) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7cc45372daa5..7e4c9089a2cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${ctx.initMutableStates()} } + public void initialize(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} // Scala.Function1 need this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1510a4796683..1b00c9e79da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -64,7 +64,15 @@ package object expressions { * column of the new row. 
If the schema of the input row is specified, then the given expression * will be bound to that schema. */ - abstract class Projection extends (InternalRow => InternalRow) + abstract class Projection extends (InternalRow => InternalRow) { + + /** + * Initializes internal states given the current partition index. + * This is used by nondeterministic expressions to set initial states. + * The default implementation does nothing. + */ + def initialize(partitionIndex: Int): Unit = {} + } /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9394e39aadd9..c941a576d00d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,10 +31,6 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { - expression.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ca200768b228..e09029f5aab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { */ @transient protected var rng: XORShiftRandom = _ - override protected def initInternal(): Unit = { - rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) } override def nullable: Boolean = false @@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } @@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e5e2cd7d27d1..b6ad5db74e3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Project(projectList, LocalRelation(output, data)) if !projectList.exists(hasUnevaluableExpr) => val projection = new InterpretedProjection(projectList, output) + projection.initialize(0) LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f0c149c02b9a..9ceb70918541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.setInitialValues() + case n: Nondeterministic => n.initialize(0) case _ => } expression.eval(inputRow) @@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { @@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { var plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initialize(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 06dc3bd33b90..fe5cb8eda824 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { // Use an Add to wrap two of them together in case we only initialize the top level expressions. 
val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = UnsafeProjection.create(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr)) + instance.initialize(0) assert(instance.apply(null).getBoolean(0) === false) } test("GeneratePredicate should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GeneratePredicate.generate(expr) - assert(instance.apply(null) === false) + instance.initialize(0) + assert(instance.eval(null) === false) } test("GenerateUnsafeProjection should not share expression instances") { @@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) - assert(instance1.apply(null) === false) + assert(instance1.eval(null) === false) val expr2 = MutableExpression() expr2.mutableState = true val instance2 = GeneratePredicate.generate(expr2) - assert(instance1.apply(null) === false) - assert(instance2.apply(null) === true) + assert(instance1.eval(null) === false) + assert(instance2.eval(null) === true) } } @@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { */ case class NondeterministicExpression() extends LeafExpression with Nondeterministic with CodegenFallback { - override protected def initInternal(): Unit = { } + override protected def initializeInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = false override def nullable: Boolean = false override def dataType: DataType = BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index fdd1fa364825..e485b52b43f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -71,8 +71,9 @@ case class RowDataSourceScanExec( val unsafeRow = if (outputUnsafeRows) { rdd } else { - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } @@ -284,8 +285,9 @@ case class FileSourceScanExec( val unsafeRows = { val scan = inputRDD if (needsUnsafeRowConversion) { - scan.mapPartitionsInternal { iter => + scan.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + proj.initialize(index) iter.map(proj) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 455fb5bfbb6f..aab087cd9871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -190,8 +190,9 @@ case class RDDScanExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) + 
proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 266312956266..19fbf0c16204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -94,8 +94,9 @@ case class GenerateExec( } val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsInternal { iter => + rows.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(output, output) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 24d0cffef82a..cadab37a449a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { GeneratePredicate.generate(expression, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6303483f22fd..516b9d5444d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co partitionIndex = index; this.inputs = inputs; ${ctx.initMutableStates()} + ${ctx.initPartition()} } ${ctx.declareAddedFunctions()} @@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } else { // Right now, we support up to two input RDDs. 
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => - val partitionIndex = TaskContext.getPartitionId() + Iterator((leftIter, rightIter)) + // a small hack to obtain the correct partition index + }.mapPartitionsWithIndex { (index, zippedIter) => + val (leftIter, rightIter) = zippedIter.next() val clazz = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.init(partitionIndex, Array(leftIter, rightIter)) + buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { override def hasNext: Boolean = { val v = buffer.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a5291e0c12f8..32133f52630c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) + project.initialize(index) iter.map(project) } } @@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val predicate = newPredicate(condition, child.output) + predicate.initialize(0) iter.filter { row => - val r = predicate(row) + val r = predicate.eval(row) if (r) numOutputRows += 1 r } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index b87016d5a569..9028caa446e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -132,10 +132,11 @@ case class InMemoryTableScanExec( val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitionsInternal { cachedBatchIterator => + buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) + partitionFilter.initialize(index) // Find the ordinals and data types of the requested columns. 
val (requestedColumnIndices, requestedColumnDataTypes) = @@ -147,7 +148,7 @@ case class InMemoryTableScanExec( val cachedBatchesToScan = if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter(cachedBatch.stats)) { + if (!partitionFilter.eval(cachedBatch.stats)) { def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index bfe7e3dea45d..f526a1987667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec( UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil } - private[this] def genResultProjection: InternalRow => InternalRow = joinType match { + private[this] def genResultProjection: UnsafeProjection = joinType match { case LeftExistence(j) => UnsafeProjection.create(output, output) case other => @@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec( @transient private lazy val boundCondition = { if (condition.isDefined) { - newPredicate(condition.get, streamed.output ++ broadcast.output) + newPredicate(condition.get, streamed.output ++ broadcast.output).eval _ } else { (r: InternalRow) => true } @@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec( } val numOutputRows = longMetric("numOutputRows") - resultRdd.mapPartitionsInternal { iter => + resultRdd.mapPartitionsWithIndexInternal { (index, iter) => val resultProj = genResultProjection + resultProj.initialize(index) iter.map { r => numOutputRows += 1 resultProj(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 15dc9b40662e..8341fe2ffd07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -98,15 +98,15 @@ case class CartesianProductExec( val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) - pair.mapPartitionsInternal { iter => + pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { - val boundCondition: (InternalRow) => Boolean = - newPredicate(condition.get, left.output ++ right.output) + val boundCondition = newPredicate(condition.get, left.output ++ right.output) + boundCondition.initialize(index) val joined = new JoinedRow iter.filter { r => - boundCondition(joined(r._1, r._2)) + boundCondition.eval(joined(r._1, r._2)) } } else { iter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 05c5e2f4cd77..1aef5f686426 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -81,7 +81,7 @@ trait HashJoin { UnsafeProjection.create(streamedKeys) @transient private[this] 
lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _ } else { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ecf7cf289f03..ca9c0ed8cec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -101,7 +101,7 @@ case class SortMergeJoinExec( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => - newPredicate(cond, left.output ++ right.output) + newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 9df56bbf1ef8..fde3b2a52899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -87,8 +87,9 @@ case class DeserializeToObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + projection.initialize(index) iter.map(projection) } } @@ -124,8 +125,9 @@ case class SerializeFromObjectExec( } override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val projection = UnsafeProjection.create(serializer) + projection.initialize(index) iter.map(projection) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 586a0fffeb7a..0e9a2c6cf7de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,13 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import scala.util.Random + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val rows = df.collect() + 
val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } + } + + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) + } +} + +object DataFrameFunctionsSuite { + case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = true + override def eval(input: InternalRow): Any = child.eval(input) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 231f204b12b4..c80695bd3e0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -154,8 +154,9 @@ case class HiveTableScanExec( val numOutputRows = longMetric("numOutputRows") // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649) val outputSchema = schema - rdd.mapPartitionsInternal { iter => + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) iter.map { r => numOutputRows += 1 proj(r) From 3c24299b71e23e159edbb972347b13430f92a465 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 2 Nov 2016 11:47:45 -0700 Subject: [PATCH 100/381] [SPARK-18160][CORE][YARN] spark.files & spark.jars should not be passed to driver in yarn mode ## What changes were proposed in this pull request? spark.files is still passed to driver in yarn mode, so SparkContext will still handle it which cause the error in the jira desc. ## How was this patch tested? Tested manually in a 5 node cluster. As this issue only happens in multiple node cluster, so I didn't write test for it. Author: Jeff Zhang Closes #15669 from zjffdu/SPARK-18160. 
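To make the intended behavior concrete, the following is a hedged sketch (simplified, with a made-up helper name, not the literal patch): in yarn mode the submit-time jar and file lists are served from the YARN distributed cache, so those entries are stripped from the `SparkConf` before the driver-side `SparkContext` can try to re-add them from paths that only exist on the submitting machine.

```scala
import org.apache.spark.SparkConf

// Keys that SparkSubmit already distributes through the YARN cache in yarn
// mode; the driver must not attempt to resolve them as local files again.
val yarnManagedConfKeys = Seq("spark.jars", "spark.files")

// Illustrative helper (name is hypothetical): drop the yarn-managed entries.
def stripYarnManagedEntries(conf: SparkConf): SparkConf = {
  yarnManagedConfKeys.foreach(conf.remove)
  conf
}

// e.g. clean the conf before handing it to the YARN client.
val sparkConf = stripYarnManagedEntries(new SparkConf())
```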
--- .../scala/org/apache/spark/SparkContext.scala | 29 ++++--------------- .../org/apache/spark/deploy/yarn/Client.scala | 5 +++- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd..63478c88b057 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1716,29 +1716,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 55e4a833b670..053a78617d4e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1202,7 +1202,10 @@ private object Client extends Logging { // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - + // SparkSubmit would use yarn cache to distribute files & jars in yarn mode, + // so remove them from sparkConf here for yarn mode. + sparkConf.remove("spark.jars") + sparkConf.remove("spark.files") val args = new ClientArguments(argStrings) new Client(args, sparkConf).run() } From 37d95227a21de602b939dae84943ba007f434513 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 2 Nov 2016 11:52:29 -0700 Subject: [PATCH 101/381] [SPARK-17058][BUILD] Add maven snapshots-and-staging profile to build/test against staging artifacts ## What changes were proposed in this pull request? Adds a `snapshots-and-staging profile` so that RCs of projects like Hadoop and HBase can be used in developer-only build and test runs. There's a comment above the profile telling people not to use this in production. There's no attempt to do the same for SBT, as Ivy is different. ## How was this patch tested? Tested by building against the Hadoop 2.7.3 RC 1 JARs without the profile (and without any local copy of the 2.7.3 artifacts), the build failed ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive -Dhadoop.version=2.7.3 ... 
[INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Launcher 2.1.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.pom [WARNING] The POM for org.apache.hadoop:hadoop-client:jar:2.7.3 is missing, no dependency information available Downloading: https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-client/2.7.3/hadoop-client-2.7.3.jar [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 4.482 s] [INFO] Spark Project Tags ................................. SUCCESS [ 17.402 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 11.252 s] [INFO] Spark Project Networking ........................... SUCCESS [ 13.458 s] [INFO] Spark Project Shuffle Streaming Service ............ SUCCESS [ 9.043 s] [INFO] Spark Project Unsafe ............................... SUCCESS [ 16.027 s] [INFO] Spark Project Launcher ............................. FAILURE [ 1.653 s] [INFO] Spark Project Core ................................. SKIPPED ... ``` With the profile, the build completed ``` mvn install -DskipTests -Pyarn,hadoop-2.7,hive,snapshots-and-staging -Dhadoop.version=2.7.3 ``` Author: Steve Loughran Closes #14646 from steveloughran/stevel/SPARK-17058-support-asf-snapshots. --- pom.xml | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/pom.xml b/pom.xml index aaf7cfa7eb2a..04d2eaa1d3ba 100644 --- a/pom.xml +++ b/pom.xml @@ -2693,6 +2693,54 @@ + + + snapshots-and-staging + + + https://repository.apache.org/content/groups/staging/ + https://repository.apache.org/content/repositories/snapshots/ + + + + + ASF Staging + ${asf.staging} + + + ASF Snapshots + ${asf.snapshots} + + true + + + false + + + + + + + ASF Staging + ${asf.staging} + + + ASF Snapshots + ${asf.snapshots} + + true + + + false + + + + + + + org.json + json + From b533fa2b205544b42dcebe0a6fee9d8275f6da7d Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Thu, 10 Nov 2016 13:41:13 -0800 Subject: [PATCH 185/381] [SPARK-17993][SQL] Fix Parquet log output redirection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (Link to Jira issue: https://issues.apache.org/jira/browse/SPARK-17993) ## What changes were proposed in this pull request? PR #14690 broke parquet log output redirection for converted partitioned Hive tables. 
For example, when querying parquet files written by Parquet-mr 1.6.0 Spark prints a torrent of (harmless) warning messages from the Parquet reader: ``` Oct 18, 2016 7:42:18 PM WARNING: org.apache.parquet.CorruptStatistics: Ignoring statistics because created_by could not be parsed (see PARQUET-251): parquet-mr version 1.6.0 org.apache.parquet.VersionParser$VersionParseException: Could not parse created_by: parquet-mr version 1.6.0 using format: (.+) version ((.*) )?\(build ?(.*)\) at org.apache.parquet.VersionParser.parse(VersionParser.java:112) at org.apache.parquet.CorruptStatistics.shouldIgnoreStatistics(CorruptStatistics.java:60) at org.apache.parquet.format.converter.ParquetMetadataConverter.fromParquetStatistics(ParquetMetadataConverter.java:263) at org.apache.parquet.hadoop.ParquetFileReader$Chunk.readAllPages(ParquetFileReader.java:583) at org.apache.parquet.hadoop.ParquetFileReader.readNextRowGroup(ParquetFileReader.java:513) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.checkEndOfRowGroup(VectorizedParquetRecordReader.java:270) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextBatch(VectorizedParquetRecordReader.java:225) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.nextKeyValue(VectorizedParquetRecordReader.java:137) at org.apache.spark.sql.execution.datasources.RecordReaderIterator.hasNext(RecordReaderIterator.scala:39) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:162) at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:102) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.scan_nextBatch$(Unknown Source) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:372) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319) at org.apache.spark.rdd.RDD.iterator(RDD.scala:283) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:99) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` This only happens during execution, not planning, and it doesn't matter what log level the `SparkContext` is set to. That's because Parquet (versions < 1.9) doesn't use slf4j for logging. Note, you can tell that log redirection is not working here because the log message format does not conform to the default Spark log message format. 
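The redirection itself is the standard JUL-to-SLF4J bridge pattern; a minimal sketch, assuming the `org.apache.parquet` logger name (the patch derives the actual name from parquet's `Log` class and also covers the older `parquet` namespace):

```
import java.util.logging.{Logger => JulLogger}
import org.slf4j.bridge.SLF4JBridgeHandler

// Hold a strong reference to the JUL logger, otherwise it can be garbage
// collected and the handler installed below would be lost.
val parquetJulLogger: JulLogger = JulLogger.getLogger("org.apache.parquet")

// Remove any existing handlers and route records through SLF4J instead.
parquetJulLogger.getHandlers.foreach(parquetJulLogger.removeHandler)
parquetJulLogger.setUseParentHandlers(false)
parquetJulLogger.addHandler(new SLF4JBridgeHandler)
```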
This is a regression I noted as something we needed to fix as a follow up. It appears that the problem arose because we removed the call to `inferSchema` during Hive table conversion. That call is what triggered the output redirection. ## How was this patch tested? I tested this manually in four ways: 1. Executing `spark.sqlContext.range(10).selectExpr("id as a").write.mode("overwrite").parquet("test")`. 2. Executing `spark.read.format("parquet").load(legacyParquetFile).show` for a Parquet file `legacyParquetFile` written using Parquet-mr 1.6.0. 3. Executing `select * from legacy_parquet_table limit 1` for some unpartitioned Parquet-based Hive table written using Parquet-mr 1.6.0. 4. Executing `select * from legacy_partitioned_parquet_table where partcol=x limit 1` for some partitioned Parquet-based Hive table written using Parquet-mr 1.6.0. I ran each test with a new instance of `spark-shell` or `spark-sql`. Incidentally, I found that test case 3 was not a regression—redirection was not occurring in the master codebase prior to #14690. I spent some time working on a unit test, but based on my experience working on this ticket I feel that automated testing here is far from feasible. cc ericl dongjoon-hyun Author: Michael Allman Closes #15538 from mallman/spark-17993-fix_parquet_log_redirection. --- .../parquet/ParquetLogRedirector.java | 72 +++++++++++++++++++ .../parquet/ParquetFileFormat.scala | 58 ++++----------- sql/core/src/test/resources/log4j.properties | 4 +- sql/hive/src/test/resources/log4j.properties | 4 ++ 4 files changed, 90 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java new file mode 100644 index 000000000000..7a7f32ee1e87 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetLogRedirector.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.Serializable; +import java.util.logging.Handler; +import java.util.logging.Logger; + +import org.apache.parquet.Log; +import org.slf4j.bridge.SLF4JBridgeHandler; + +// Redirects the JUL logging for parquet-mr versions <= 1.8 to SLF4J logging using +// SLF4JBridgeHandler. Parquet-mr versions >= 1.9 use SLF4J directly +final class ParquetLogRedirector implements Serializable { + // Client classes should hold a reference to INSTANCE to ensure redirection occurs. 
This is + // especially important for Serializable classes where fields are set but constructors are + // ignored + static final ParquetLogRedirector INSTANCE = new ParquetLogRedirector(); + + // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. + // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep + // references to loggers in both parquet-mr <= 1.6 and 1.7/1.8 + private static final Logger apacheParquetLogger = + Logger.getLogger(Log.class.getPackage().getName()); + private static final Logger parquetLogger = Logger.getLogger("parquet"); + + static { + // For parquet-mr 1.7 and 1.8, which are under `org.apache.parquet` namespace. + try { + Class.forName(Log.class.getName()); + redirect(Logger.getLogger(Log.class.getPackage().getName())); + } catch (ClassNotFoundException ex) { + throw new RuntimeException(ex); + } + + // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` + // namespace. + try { + Class.forName("parquet.Log"); + redirect(Logger.getLogger("parquet")); + } catch (Throwable t) { + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly + // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block + // should be removed after this issue is fixed. + } + } + + private ParquetLogRedirector() { + } + + private static void redirect(Logger logger) { + for (Handler handler : logger.getHandlers()) { + logger.removeHandler(handler); + } + logger.setUseParentHandlers(false); + logger.addHandler(new SLF4JBridgeHandler()); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b8ea7f40c4ab..031a0fe57893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,14 +28,12 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -56,6 +53,11 @@ class ParquetFileFormat with DataSourceRegister with Logging with Serializable { + // Hold a reference to the (serializable) singleton instance of ParquetLogRedirector. This + // ensures the ParquetLogRedirector class is initialized whether an instance of ParquetFileFormat + // is constructed or deserialized. Do not heed the Scala compiler's warning about an unused field + // here. 
+ private val parquetLogRedirector = ParquetLogRedirector.INSTANCE override def shortName(): String = "parquet" @@ -129,10 +131,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } - ParquetFileFormat.redirectParquetLogs() - new OutputWriterFactory { - override def newInstance( + // This OutputWriterFactory instance is deserialized when writing Parquet files on the + // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold + // another reference to ParquetLogRedirector.INSTANCE here to ensure the latter class is + // initialized. + private val parquetLogRedirector = ParquetLogRedirector.INSTANCE + + override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { @@ -673,44 +679,4 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } - - // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. - // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep - // references to loggers in both parquet-mr <= 1.6 and >= 1.7 - val apacheParquetLogger: JLogger = JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName) - val parquetLogger: JLogger = JLogger.getLogger("parquet") - - // Parquet initializes its own JUL logger in a static block which always prints to stdout. Here - // we redirect the JUL logger via SLF4J JUL bridge handler. - val redirectParquetLogsViaSLF4J: Unit = { - def redirect(logger: JLogger): Unit = { - logger.getHandlers.foreach(logger.removeHandler) - logger.setUseParentHandlers(false) - logger.addHandler(new SLF4JBridgeHandler) - } - - // For parquet-mr 1.7.0 and above versions, which are under `org.apache.parquet` namespace. - // scalastyle:off classforname - Class.forName(classOf[ApacheParquetLog].getName) - // scalastyle:on classforname - redirect(JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName)) - - // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` - // namespace. - try { - // scalastyle:off classforname - Class.forName("parquet.Log") - // scalastyle:on classforname - redirect(JLogger.getLogger("parquet")) - } catch { case _: Throwable => - // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly - // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block - // should be removed after this issue is fixed. - } - } - - /** - * ParquetFileFormat.prepareWrite calls this function to initialize `redirectParquetLogsViaSLF4J`. 
- */ - def redirectParquetLogs(): Unit = {} } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 33b9ecf1e282..25b817382195 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -53,5 +53,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.org.apache.parquet.hadoop=WARN -log4j.logger.org.apache.spark.sql.parquet=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index fea3404769d9..072bb25d30a8 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -59,3 +59,7 @@ log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR + +# Parquet related logging +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR From 2f7461f31331cfc37f6cfa3586b7bbefb3af5547 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Nov 2016 13:42:48 -0800 Subject: [PATCH 186/381] [SPARK-17990][SPARK-18302][SQL] correct several partition related behaviours of ExternalCatalog ## What changes were proposed in this pull request? This PR corrects several partition related behaviors of `ExternalCatalog`: 1. default partition location should not always lower case the partition column names in path string(fix `HiveExternalCatalog`) 2. rename partition should not always lower case the partition column names in updated partition path string(fix `HiveExternalCatalog`) 3. rename partition should update the partition location only for managed table(fix `InMemoryCatalog`) 4. create partition with existing directory should be fine(fix `InMemoryCatalog`) 5. create partition with non-existing directory should create that directory(fix `InMemoryCatalog`) 6. drop partition from external table should not delete the directory(fix `InMemoryCatalog`) ## How was this patch tested? new tests in `ExternalCatalogSuite` Author: Wenchen Fan Closes #15797 from cloud-fan/partition. 
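For the first corrected behaviour (default partition locations keep the original partition column name casing), the gist is sketched below. Escaping of special characters and the `__HIVE_DEFAULT_PARTITION__` null marker are omitted for brevity (the helper added in this patch handles both), and the function name here is only illustrative:

```
import org.apache.hadoop.fs.Path

// Build the default partition location under the table directory while
// preserving the partition column name casing, e.g. partCol1=1/partCol2=2.
def defaultPartitionPath(
    spec: Map[String, String],
    partitionColumnNames: Seq[String],
    tablePath: Path): Path = {
  partitionColumnNames.foldLeft(tablePath) { (parent, col) =>
    new Path(parent, s"${col}=${spec(col)}")
  }
}

// defaultPartitionPath(
//   Map("partCol1" -> "1", "partCol2" -> "2"),
//   Seq("partCol1", "partCol2"),
//   new Path("/warehouse/db1.db/tbl"))
// => /warehouse/db1.db/tbl/partCol1=1/partCol2=2
```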
--- .../catalog/ExternalCatalogUtils.scala | 121 ++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 92 +++++------ .../sql/catalyst/catalog/interface.scala | 11 ++ .../catalog/ExternalCatalogSuite.scala | 150 ++++++++++++++---- .../catalog/SessionCatalogSuite.scala | 24 ++- .../spark/sql/execution/command/ddl.scala | 8 +- .../spark/sql/execution/command/tables.scala | 3 +- .../datasources/CatalogFileIndex.scala | 2 +- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileFormatWriter.scala | 6 +- .../PartitioningAwareFileIndex.scala | 2 - .../datasources/PartitioningUtils.scala | 94 +---------- .../sql/execution/command/DDLSuite.scala | 8 +- .../ParquetPartitionDiscoverySuite.scala | 21 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 51 +++++- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 19 files changed, 397 insertions(+), 208 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala new file mode 100644 index 000000000000..b1442eec164d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell + +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec + +object ExternalCatalogUtils { + // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't + // depend on Hive. + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. 
+ */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02X") + } else { + builder.append(c) + } + } + + builder.toString() + } + + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.parseInt(path.substring(i + 1, i + 3), 16) + } catch { + case _: Exception => -1 + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } + + def generatePartitionPath( + spec: TablePartitionSpec, + partitionColumnNames: Seq[String], + tablePath: Path): Path = { + val partitionPathStrings = partitionColumnNames.map { col => + val partitionValue = spec(col) + val partitionString = if (partitionValue == null) { + DEFAULT_PARTITION_NAME + } else { + escapePathName(partitionValue) + } + escapePathName(col) + "=" + partitionString + } + partitionPathStrings.foldLeft(tablePath) { (totalPath, nextPartPath) => + new Path(totalPath, nextPartPath) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 20db81e6f906..a3ffeaa63f69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -231,7 +231,7 @@ class InMemoryCatalog( assert(tableMeta.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val dir = new Path(tableMeta.storage.locationUri.get) + val dir = new Path(tableMeta.location) try { val fs = dir.getFileSystem(hadoopConfig) fs.delete(dir, true) @@ -259,7 +259,7 @@ class InMemoryCatalog( assert(oldDesc.table.storage.locationUri.isDefined, "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") - val oldDir = new Path(oldDesc.table.storage.locationUri.get) + val oldDir = new Path(oldDesc.table.location) val newDir = new Path(catalog(db).db.locationUri, newName) try { val fs = oldDir.getFileSystem(hadoopConfig) @@ -355,25 +355,28 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) // TODO: we should follow hive to roll back if one partition path 
failed to create. parts.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (p.storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.spec.get(col).map(col + "=" + _) - }.mkString("/") - try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.mkdirs(new Path(tableDir, partitionPath)) - } catch { - case e: IOException => - throw new SparkException(s"Unable to create partition path $partitionPath", e) + val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse { + ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) + } + + try { + val fs = tablePath.getFileSystem(hadoopConfig) + if (!fs.exists(partitionPath)) { + fs.mkdirs(partitionPath) } + } catch { + case e: IOException => + throw new SparkException(s"Unable to create partition path $partitionPath", e) } - existingParts.put(p.spec, p) + + existingParts.put( + p.spec, + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) } } @@ -392,19 +395,15 @@ class InMemoryCatalog( } } - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames - // TODO: we should follow hive to roll back if one partition path failed to delete. + val shouldRemovePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + // TODO: we should follow hive to roll back if one partition path failed to delete, and support + // partial partition spec. partSpecs.foreach { p => - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. - if (existingParts.contains(p) && existingParts(p).storage.locationUri.isEmpty) { - val partitionPath = partitionColumnNames.flatMap { col => - p.get(col).map(col + "=" + _) - }.mkString("/") + if (existingParts.contains(p) && shouldRemovePartitionLocation) { + val partitionPath = new Path(existingParts(p).location) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.delete(new Path(tableDir, partitionPath), true) + val fs = partitionPath.getFileSystem(hadoopConfig) + fs.delete(partitionPath, true) } catch { case e: IOException => throw new SparkException(s"Unable to delete partition path $partitionPath", e) @@ -423,33 +422,34 @@ class InMemoryCatalog( requirePartitionsExist(db, table, specs) requirePartitionsNotExist(db, table, newSpecs) - val tableDir = new Path(catalog(db).db.locationUri, table) - val partitionColumnNames = getTable(db, table).partitionColumnNames + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) + val shouldUpdatePartitionLocation = getTable(db, table).tableType == CatalogTableType.MANAGED + val existingParts = catalog(db).tables(table).partitions // TODO: we should follow hive to roll back if one partition path failed to rename. specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => - val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) - val existingParts = catalog(db).tables(table).partitions - - // If location is set, the partition is using an external partition location and we don't - // need to handle its directory. 
- if (newPart.storage.locationUri.isEmpty) { - val oldPath = partitionColumnNames.flatMap { col => - oldSpec.get(col).map(col + "=" + _) - }.mkString("/") - val newPath = partitionColumnNames.flatMap { col => - newSpec.get(col).map(col + "=" + _) - }.mkString("/") + val oldPartition = getPartition(db, table, oldSpec) + val newPartition = if (shouldUpdatePartitionLocation) { + val oldPartPath = new Path(oldPartition.location) + val newPartPath = ExternalCatalogUtils.generatePartitionPath( + newSpec, partitionColumnNames, tablePath) try { - val fs = tableDir.getFileSystem(hadoopConfig) - fs.rename(new Path(tableDir, oldPath), new Path(tableDir, newPath)) + val fs = tablePath.getFileSystem(hadoopConfig) + fs.rename(oldPartPath, newPartPath) } catch { case e: IOException => - throw new SparkException(s"Unable to rename partition path $oldPath", e) + throw new SparkException(s"Unable to rename partition path $oldPartPath", e) } + oldPartition.copy( + spec = newSpec, + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + } else { + oldPartition.copy(spec = newSpec) } existingParts.remove(oldSpec) - existingParts.put(newSpec, newPart) + existingParts.put(newSpec, newPartition) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 34748a04859a..93c70de18ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -99,6 +99,12 @@ case class CatalogTablePartition( output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") } + /** Return the partition location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") + throw new AnalysisException(s"Partition [$specString] did not specify locationUri") + } + /** * Given the partition schema, returns a row with that schema holding the partition values. */ @@ -171,6 +177,11 @@ case class CatalogTable( throw new AnalysisException(s"table $identifier did not specify database") } + /** Return the table location, assuming it is specified. */ + def location: String = storage.locationUri.getOrElse { + throw new AnalysisException(s"table $identifier did not specify locationUri") + } + /** Return the fully qualified name of this table, assuming the database was specified. 
*/ def qualifiedName: String = identifier.unquotedString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 34bdfc8a9871..303a8662d3f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.catalog -import java.io.File -import java.net.URI - +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite @@ -320,6 +319,33 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) } + test("create partitions without location") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val partition = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(partition), ignoreIfExists = false) + + val partitionLocation = catalog.getPartition( + "db1", + "tbl", + Map("partCol1" -> "1", "partCol2" -> "2")).location + val tableLocation = catalog.getTable("db1", "tbl").location + val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") + assert(new Path(partitionLocation) == defaultPartitionLocation) + } + test("list partitions with partial partition spec") { val catalog = newBasicCatalog() val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1"))) @@ -399,6 +425,46 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part2.spec) } } + test("rename partitions should update the location for managed table") { + val catalog = newBasicCatalog() + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string"), + provider = Some("hive"), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val tableLocation = catalog.getTable("db1", "tbl").location + + val mixedCasePart1 = CatalogTablePartition( + Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val mixedCasePart2 = CatalogTablePartition( + Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + + catalog.createPartitions("db1", "tbl", Seq(mixedCasePart1), ignoreIfExists = false) + assert( + new Path(catalog.getPartition("db1", "tbl", mixedCasePart1.spec).location) == + new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2")) + + catalog.renamePartitions("db1", "tbl", Seq(mixedCasePart1.spec), Seq(mixedCasePart2.spec)) + assert( 
+ new Path(catalog.getPartition("db1", "tbl", mixedCasePart2.spec).location) == + new Path(new Path(tableLocation, "partCol1=3"), "partCol2=4")) + + // For external tables, RENAME PARTITION should not update the partition location. + val existingPartLoc = catalog.getPartition("db2", "tbl2", part1.spec).location + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec), Seq(part3.spec)) + assert( + new Path(catalog.getPartition("db2", "tbl2", part3.spec).location) == + new Path(existingPartLoc)) + } + test("rename partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { @@ -419,11 +485,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter partitions") { val catalog = newBasicCatalog() try { - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - catalog.setCurrentDatabase("db2") val newLocation = newUriForDatabase() val newSerde = "com.sparkbricks.text.EasySerde" val newSerdeProps = Map("spark" -> "bricks", "compressed" -> "false") @@ -571,10 +632,11 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // -------------------------------------------------------------------------- private def exists(uri: String, children: String*): Boolean = { - val base = new File(new URI(uri)) - children.foldLeft(base) { - case (parent, child) => new File(parent, child) - }.exists() + val base = new Path(uri) + val finalPath = children.foldLeft(base) { + case (parent, child) => new Path(parent, child) + } + base.getFileSystem(new Configuration()).exists(finalPath) } test("create/drop database should create/delete the directory") { @@ -623,7 +685,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("create/drop/rename partitions should create/delete/rename the directory") { val catalog = newBasicCatalog() - val databaseDir = catalog.getDatabase("db1").locationUri val table = CatalogTable( identifier = TableIdentifier("tbl", Some("db1")), tableType = CatalogTableType.MANAGED, @@ -631,34 +692,61 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac schema = new StructType() .add("col1", "int") .add("col2", "string") - .add("a", "int") - .add("b", "string"), + .add("partCol1", "int") + .add("partCol2", "string"), provider = Some("hive"), - partitionColumnNames = Seq("a", "b") - ) + partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) + val tableLocation = catalog.getTable("db1", "tbl").location + + val part1 = CatalogTablePartition(Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) + val part2 = CatalogTablePartition(Map("partCol1" -> "3", "partCol2" -> "4"), storageFormat) + val part3 = CatalogTablePartition(Map("partCol1" -> "5", "partCol2" -> "6"), storageFormat) + catalog.createPartitions("db1", "tbl", Seq(part1, part2), ignoreIfExists = false) - assert(exists(databaseDir, "tbl", "a=1", "b=2")) - assert(exists(databaseDir, "tbl", "a=3", "b=4")) + assert(exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=3", "partCol2=4")) catalog.renamePartitions("db1", "tbl", Seq(part1.spec), Seq(part3.spec)) - assert(!exists(databaseDir, "tbl", "a=1", "b=2")) - 
assert(exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=1", "partCol2=2")) + assert(exists(tableLocation, "partCol1=5", "partCol2=6")) catalog.dropPartitions("db1", "tbl", Seq(part2.spec, part3.spec), ignoreIfNotExists = false, purge = false) - assert(!exists(databaseDir, "tbl", "a=3", "b=4")) - assert(!exists(databaseDir, "tbl", "a=5", "b=6")) + assert(!exists(tableLocation, "partCol1=3", "partCol2=4")) + assert(!exists(tableLocation, "partCol1=5", "partCol2=6")) - val externalPartition = CatalogTablePartition( - Map("a" -> "7", "b" -> "8"), + val tempPath = Utils.createTempDir() + // create partition with existing directory is OK. + val partWithExistingDir = CatalogTablePartition( + Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), - None, None, None, false, Map.empty) - ) - catalog.createPartitions("db1", "tbl", Seq(externalPartition), ignoreIfExists = false) - assert(!exists(databaseDir, "tbl", "a=7", "b=8")) + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false) + + tempPath.delete() + // create partition with non-existing directory will create that directory. + val partWithNonExistingDir = CatalogTablePartition( + Map("partCol1" -> "9", "partCol2" -> "10"), + CatalogStorageFormat( + Some(tempPath.getAbsolutePath), + None, None, None, false, Map.empty)) + catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) + assert(tempPath.exists()) + } + + test("drop partition from external table should not delete the directory") { + val catalog = newBasicCatalog() + catalog.createPartitions("db2", "tbl1", Seq(part1), ignoreIfExists = false) + + val partPath = new Path(catalog.getPartition("db2", "tbl1", part1.spec).location) + val fs = partPath.getFileSystem(new Configuration) + assert(fs.exists(partPath)) + + catalog.dropPartitions("db2", "tbl1", Seq(part1.spec), ignoreIfNotExists = false, purge = false) + assert(fs.exists(partPath)) } } @@ -731,7 +819,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat, + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 001d9c47785d..52385de50db6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -527,13 +527,13 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) sessionCatalog.createPartitions( TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) // Create partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") sessionCatalog.createPartitions( TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) 
assert(catalogPartitionsEqual( - externalCatalog, "mydb", "tbl", Seq(part1, part2, partWithMixedOrder))) + externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) } test("create partitions when database/table does not exist") { @@ -586,13 +586,13 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop partitions") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false, purge = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( @@ -604,7 +604,7 @@ class SessionCatalogSuite extends SparkFunSuite { // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) sessionCatalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), @@ -844,10 +844,11 @@ class SessionCatalogSuite extends SparkFunSuite { test("list partitions") { val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) // List partitions without explicitly specifying database catalog.setCurrentDatabase("db2") - assert(catalog.listPartitions(TableIdentifier("tbl2")).toSet == Set(part1, part2)) + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) } test("list partitions when database/table does not exist") { @@ -860,6 +861,15 @@ class SessionCatalogSuite extends SparkFunSuite { } } + private def catalogPartitionsEqual( + actualParts: Seq[CatalogTablePartition], + expectedParts: CatalogTablePartition*): Boolean = { + // ExternalCatalog may set a default location for partitions, here we ignore the partition + // location when comparing them. 
+ actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == + expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 8500ab460a1b..84a63fdb9f36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils} @@ -500,7 +500,7 @@ case class AlterTableRecoverPartitionsCommand( s"location provided: $tableIdentWithDB") } - val root = new Path(table.storage.locationUri.get) + val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) @@ -558,9 +558,9 @@ case class AlterTableRecoverPartitionsCommand( val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) - val columnName = PartitioningUtils.unescapePathName(ps(0)) + val columnName = ExternalCatalogUtils.unescapePathName(ps(0)) // TODO: Validate the value - val value = PartitioningUtils.unescapePathName(ps(1)) + val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), partitionNames.drop(1), threshold, resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e49a1f5acd0c..119e732d0202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -710,7 +710,8 @@ case class ShowPartitionsCommand( private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { partColNames.map { name => - PartitioningUtils.escapePathName(name) + "=" + PartitioningUtils.escapePathName(spec(name)) + ExternalCatalogUtils.escapePathName(name) + "=" + + ExternalCatalogUtils.escapePathName(spec(name)) }.mkString(File.separator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 443a2ec033a9..4ad91dcceb43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -67,7 +67,7 @@ class 
CatalogFileIndex( val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => - val path = new Path(p.storage.locationUri.get) + val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d43a6ad098e..739aeac877b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -190,7 +190,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { val effectiveOutputPath = if (overwritingSinglePartition) { val partition = t.sparkSession.sessionState.catalog.getPartition( l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.storage.locationUri.get) + new Path(partition.location) } else { outputPath } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index fa7fe143daeb..69b3fa667ef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning @@ -281,11 +281,11 @@ object FileFormatWriter extends Logging { private def partitionStringExpression: Seq[Expression] = { description.partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, + ExternalCatalogUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) - val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) + val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index a8a722dd3c62..3740caa22c37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -128,7 +128,6 @@ abstract class PartitioningAwareFileIndex( case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, - 
PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false, basePaths = basePaths) @@ -148,7 +147,6 @@ abstract class PartitioningAwareFileIndex( case _ => PartitioningUtils.parsePartitions( leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index b51b41869bf0..a28b04ca3fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,7 +25,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow @@ -56,15 +55,15 @@ object PartitionSpec { } object PartitioningUtils { - // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't - // depend on Hive. - val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { require(columnNames.size == literals.size) } + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName + /** * Given a group of qualified paths, tries to parse them and returns a partition specification. * For example, given: @@ -90,12 +89,11 @@ object PartitioningUtils { */ private[datasources] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, defaultPartitionName, typeInference, basePaths) + parsePartition(path, typeInference, basePaths) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +171,6 @@ object PartitioningUtils { */ private[datasources] def parsePartition( path: Path, - defaultPartitionName: String, typeInference: Boolean, basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] @@ -196,7 +193,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. 
@@ -228,7 +225,6 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String, typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { @@ -240,7 +236,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference) Some(columnName -> literal) } } @@ -355,7 +351,6 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String, typeInference: Boolean): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. @@ -380,14 +375,14 @@ object PartitioningUtils { .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) } } } else { - if (raw == defaultPartitionName) { + if (raw == DEFAULT_PARTITION_NAME) { Literal.create(null, NullType) } else { Literal.create(unescapePathName(raw), StringType) @@ -450,77 +445,4 @@ object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. 
- */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) - } - - bitSet - } - - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02X") - } else { - builder.append(c) - } - } - - builder.toString() - } - - def unescapePathName(path: String): String = { - val sb = new StringBuilder - var i = 0 - - while (i < path.length) { - val c = path.charAt(i) - if (c == '%' && i + 2 < path.length) { - val code: Int = try { - Integer.parseInt(path.substring(i + 1, i + 3), 16) - } catch { - case _: Exception => -1 - } - if (code >= 0) { - sb.append(code.asInstanceOf[Char]) - i += 3 - } else { - sb.append(c) - i += 1 - } - } else { - sb.append(c) - i += 1 - } - } - - sb.toString() - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index df3a3c34c39a..363715c6d224 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -875,7 +875,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) @@ -1133,7 +1133,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) // Verify that the location is set to the expected string def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { @@ -1296,9 +1296,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) - assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + 
assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 120a3a2ef33a..22e35a1bc0b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,6 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} @@ -48,11 +49,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha import PartitioningUtils._ import testImplicits._ - val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) + assert(inferPartitionColumnValue(raw, true) === literal) } check("10", Literal.create(10, IntegerType)) @@ -76,7 +77,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -88,7 +89,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) @@ -101,7 +101,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/something=true/table"))) @@ -114,7 +113,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/table=true"))) @@ -127,7 +125,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/path/"))) } @@ -147,7 +144,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha exception = intercept[AssertionError] { parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, Set(new Path("hdfs://host:9000/tmp/tables/"))) } @@ -156,13 +152,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + val actual = 
parsePartition(new Path(path), true, Set.empty[Path])._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -204,7 +200,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the same as the path to a leaf directory val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path/a=10")))._1 @@ -213,7 +208,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // when the basePaths is the path to a base directory of leaf directories val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), - defaultPartitionName = defaultPartitionName, typeInference = true, basePaths = Set(new Path("file://path")))._1 @@ -231,7 +225,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val actualSpec = parsePartitions( paths.map(new Path(_)), - defaultPartitionName, true, rootPaths) assert(actualSpec === spec) @@ -314,7 +307,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) assert(actualSpec === spec) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b537061d0d22..42ce1a88a2b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.io.IOException import java.util import scala.util.control.NonFatal @@ -26,7 +27,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.thrift.TException -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -255,7 +256,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. 
require(tableDefinition.storage.locationUri.isDefined, "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(tableDefinition.storage.locationUri.get).toUri.toString) + Some(new Path(tableDefinition.location).toUri.toString) } else { None } @@ -789,7 +790,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = withClient { requireTableExists(db, table) - val lowerCasedParts = parts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + val tablePath = new Path(tableMeta.location) + val partsWithLocation = parts.map { p => + // Ideally we can leave the partition location empty and let Hive metastore to set it. + // However, Hive metastore is not case preserving and will generate wrong partition location + // with lower cased partition column names. Here we set the default partition location + // manually to avoid this problem. + val partitionPath = p.storage.locationUri.map(new Path(_)).getOrElse { + ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) + } + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString))) + } + val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) } @@ -810,6 +825,31 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat newSpecs: Seq[TablePartitionSpec]): Unit = withClient { client.renamePartitions( db, table, specs.map(lowerCasePartitionSpec), newSpecs.map(lowerCasePartitionSpec)) + + val tableMeta = getTable(db, table) + val partitionColumnNames = tableMeta.partitionColumnNames + // Hive metastore is not case preserving and keeps partition columns with lower cased names. + // When Hive rename partition for managed tables, it will create the partition location with + // a default path generate by the new spec with lower cased partition column names. This is + // unexpected and we need to rename them manually and alter the partition location. + val hasUpperCasePartitionColumn = partitionColumnNames.exists(col => col.toLowerCase != col) + if (tableMeta.tableType == MANAGED && hasUpperCasePartitionColumn) { + val tablePath = new Path(tableMeta.location) + val newParts = newSpecs.map { spec => + val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec)) + val wrongPath = new Path(partition.location) + val rightPath = ExternalCatalogUtils.generatePartitionPath( + spec, partitionColumnNames, tablePath) + try { + tablePath.getFileSystem(hadoopConf).rename(wrongPath, rightPath) + } catch { + case e: IOException => throw new SparkException( + s"Unable to rename partition path from $wrongPath to $rightPath", e) + } + partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString))) + } + alterPartitions(db, table, newParts) + } } override def alterPartitions( @@ -817,6 +857,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. 
Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + client.setCurrentDatabase(db) client.alterPartitions(db, table, lowerCasedParts) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d3873cf6c823..fbd705172cae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -445,7 +445,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") @@ -461,7 +461,7 @@ object SetWarehouseLocationTest extends Logging { catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) val expectedLocation = "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" - val actualLocation = tableMetadata.storage.locationUri.get + val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( s"Expected table location is $expectedLocation. But, it is actually $actualLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index cfc1d81d544e..9f4401ae2256 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -29,7 +29,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val expectedPath = spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName - assert(metastoreTable.storage.locationUri.get === expectedPath) + assert(metastoreTable.location === expectedPath) } private def getTableNames(dbName: Option[String] = None): Array[String] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0076a778683c..6efae13ddf69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -425,7 +425,7 @@ class HiveDDLSuite sql("CREATE TABLE tab1 (height INT, length INT) PARTITIONED BY (a INT, b INT)") val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") - val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c21db3595fa1..e607af67f93e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -542,7 +542,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } userSpecifiedLocation match { case Some(location) => - assert(r.catalogTable.storage.locationUri.get === location) + assert(r.catalogTable.location === location) case None => // OK. } // Also make sure that the format and serde are as desired. From e0deee1f7df31177cfc14bbb296f0baa372f473d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 10 Nov 2016 13:44:54 -0800 Subject: [PATCH 187/381] [SPARK-18403][SQL] Temporarily disable flaky ObjectHashAggregateSuite ## What changes were proposed in this pull request? Randomized tests in `ObjectHashAggregateSuite` is being flaky and breaks PR builds. This PR disables them temporarily to bring back the PR build. ## How was this patch tested? N/A Author: Cheng Lian Closes #15845 from liancheng/ignore-flaky-object-hash-agg-suite. --- .../spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 93fc5e8a5e37..b7f91d8c3a79 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -326,7 +326,8 @@ class ObjectHashAggregateSuite // Currently Spark SQL doesn't support evaluating distinct aggregate function together // with aggregate functions without partial aggregation support. if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) { - test( + // TODO Re-enables them after fixing SPARK-18403 + ignore( s"randomized aggregation test - " + s"${names.mkString("[", ", ", "]")} - " + s"${if (withGroupingKeys) "with" else "without"} grouping keys - " + From a3356343cbf58b930326f45721fb4ecade6f8029 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 10 Nov 2016 17:00:43 -0800 Subject: [PATCH 188/381] [SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables ## What changes were proposed in this pull request? As of current 2.1, INSERT OVERWRITE with dynamic partitions against a Datasource table will overwrite the entire table instead of only the partitions matching the static keys, as in Hive. It also doesn't respect custom partition locations. This PR adds support for all these operations to Datasource tables managed by the Hive metastore. It is implemented as follows - During planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore. - The planner identifies any partitions with custom locations and includes this in the write task metadata. - FileFormatWriter tasks refer to this custom locations map when determining where to write for dynamic partition output. - When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added / removed partitions. It was necessary to introduce a method for staging files with absolute output paths to `FileCommitProtocol`. These files are not handled by the Hadoop output committer but are moved to their final locations when the job commits. 
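To make that mechanism concrete: a task that must write to an absolute destination (a custom partition location) writes into a job-scoped staging directory instead, records the temp-to-final mapping, and ships that mapping to the driver with its task commit; on job commit the driver renames each staged file into place and removes the staging directory. The snippet below is a minimal, self-contained sketch of that pattern only, not the code added by this patch; the names `AbsPathStager` and `newTempFileFor` are illustrative.

```scala
import java.util.UUID
import scala.collection.mutable
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Sketch of "stage under a job-scoped temp dir, rename into place on job commit".
class AbsPathStager(jobId: String, outputPath: String, conf: Configuration) {
  // Job-scoped staging directory for files whose final destination is an absolute path.
  private val stagingDir = new Path(outputPath, "_temporary-" + jobId)
  // temp path -> final absolute path, accumulated on the task side.
  private val staged = mutable.Map[String, String]()

  // Task side: allocate a temp file for a write whose destination is an absolute directory.
  def newTempFileFor(absoluteDir: String, filename: String): String = {
    val finalPath = new Path(absoluteDir, filename).toString
    // A UUID prevents collisions when one task writes the same filename to different dirs.
    val tempPath = new Path(stagingDir, UUID.randomUUID().toString + "-" + filename).toString
    staged(tempPath) = finalPath
    tempPath
  }

  // Task side: the mapping a task would send back to the driver when it commits.
  def taskCommitMapping: Map[String, String] = staged.toMap

  // Driver side: after all tasks commit, move staged files to their destinations and clean up.
  def commitJob(allStaged: Map[String, String]): Unit = {
    val fs = stagingDir.getFileSystem(conf)
    allStaged.foreach { case (temp, dest) => fs.rename(new Path(temp), new Path(dest)) }
    fs.delete(stagingDir, true)
  }
}
```

In the actual change the mapping travels to the driver inside the `TaskCommitMessage`, and the rename and cleanup happen in `HadoopMapReduceCommitProtocol.commitJob` (see the diff below).
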
The overwrite behavior of legacy Datasource tables is also changed: no longer will the entire table be overwritten if a partial partition spec is present. cc cloud-fan yhuai ## How was this patch tested? Unit tests, existing tests. Author: Eric Liang Author: Wenchen Fan Closes #15814 from ericl/sc-5027. --- .../internal/io/FileCommitProtocol.scala | 15 ++ .../io/HadoopMapReduceCommitProtocol.scala | 63 ++++++- .../sql/catalyst/parser/AstBuilder.scala | 12 +- .../plans/logical/basicLogicalOperators.scala | 10 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../execution/datasources/DataSource.scala | 20 ++- .../datasources/DataSourceStrategy.scala | 94 +++++++--- .../datasources/FileFormatWriter.scala | 26 ++- .../InsertIntoHadoopFsRelationCommand.scala | 61 ++++++- .../datasources/PartitioningUtils.scala | 10 ++ .../execution/streaming/FileStreamSink.scala | 2 +- .../ManifestFileCommitProtocol.scala | 6 + .../PartitionProviderCompatibilitySuite.scala | 161 +++++++++++++++++- 13 files changed, 411 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index fb8020585cf8..afd2250c93a8 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -82,9 +82,24 @@ abstract class FileCommitProtocol { * * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. */ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. */ diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6b0bcb8f908b..b2d9b8d2a012 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -17,7 +17,9 @@ package org.apache.spark.internal.io -import java.util.Date +import java.util.{Date, UUID} + +import scala.collection.mutable import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -42,6 +44,19 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. 
*/ @transient private var committer: OutputCommitter = _ + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -54,11 +69,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - val filename = f"part-$split%05d-$jobId$ext" + val filename = getFilename(taskContext, ext) val stagingDir: String = committer match { // For FileOutputCommitter it has its own staging path called "work path". @@ -73,6 +84,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) } } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) @@ -93,26 +126,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) } override def setupTask(taskContext: TaskAttemptContext): Unit = { committer = setupCommitter(taskContext) committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - EmptyTaskCommitMessage + new TaskCommitMessage(addedAbsPathFiles.toMap) } override def abortTask(taskContext: TaskAttemptContext): Unit = { committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } } /** Whether we are using a direct output committer */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2c4db0d2c342..3fa7bf1cdbf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableIdent = visitTableIdentifier(ctx.tableIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty) + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } val overwrite = ctx.OVERWRITE != null - val overwritePartition = - if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { - Some(partitionKeys.map(t => (t._1, t._2.get))) - } else { - None - } + val staticPartitionKeys: Map[String, String] = + partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get)) InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - OverwriteOptions(overwrite, overwritePartition), + OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index dcae7b026f58..4dcc2885536e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -349,13 +349,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { * Options for writing new data into a table. * * @param enabled whether to overwrite existing data in the table. - * @param specificPartition only data in the specified partition will be overwritten. + * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions + * that match this partial partition spec. If empty, all partitions + * will be overwritten. */ case class OverwriteOptions( enabled: Boolean, - specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { - if (specificPartition.isDefined) { - assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) { + if (staticPartitionKeys.nonEmpty) { + assert(enabled, "Overwrite must be enabled when specifying specific partitions.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5f0f6ee479c6..9aae520ae664 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest { OverwriteOptions( overwrite, if (overwrite && partition.nonEmpty) { - Some(partition.map(kv => (kv._1, kv._2.get))) + partition.map(kv => (kv._1, kv._2.get)) } else { - None + Map.empty }), ifNotExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5d663949df6b..65422f1495f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -417,15 +417,17 @@ case class DataSource( // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( - outputPath, - columns, - bucketSpec, - format, - _ => Unit, // No existing table needs to be refreshed. 
- options, - data.logicalPlan, - mode, - catalogTable) + outputPath = outputPath, + staticPartitionKeys = Map.empty, + customPartitionLocations = Map.empty, + partitionColumns = columns, + bucketSpec = bucketSpec, + fileFormat = format, + refreshFunction = _ => Unit, // No existing table needs to be refreshed. + options = options, + query = data.logicalPlan, + mode = mode, + catalogTable = catalogTable) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 739aeac877b9..4f19a2d00b0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - val overwritingSinglePartition = - overwrite.specificPartition.isDefined && + val partitionSchema = query.resolve( + t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + val partitionsTrackedByCatalog = t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.tracksPartitionsInCatalog - val effectiveOutputPath = if (overwritingSinglePartition) { - val partition = t.sparkSession.sessionState.catalog.getPartition( - l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.location) - } else { - outputPath - } - - val effectivePartitionSchema = if (overwritingSinglePartition) { - Nil - } else { - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil + 
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( + l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) } + // Callback for updating metastore partition metadata after the insertion job completes. + // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && - l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.tracksPartitionsInCatalog) { - val metastoreUpdater = AlterTableAddPartitionCommand( - l.catalogTable.get.identifier, - updatedPartitions.map(p => (p, None)), - ifNotExists = true) - metastoreUpdater.run(t.sparkSession) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(t.sparkSession) + } + if (overwrite.enabled) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + l.catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = true).run(t.sparkSession) + } + } } t.location.refresh() } val insertCmd = InsertIntoHadoopFsRelationCommand( - effectiveOutputPath, - effectivePartitionSchema, + outputPath, + if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty, + customPartitionLocations, + partitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, @@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { insertCmd } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. 
+ * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + spark: SparkSession, + table: CatalogTable, + basePath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + val hadoopConf = spark.sessionState.newHadoopConf + val fs = basePath.getFileSystem(hadoopConf) + val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory) + partitions.flatMap { p => + val defaultLocation = qualifiedBasePath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 69b3fa667ef5..4e4b0e48cd7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { + /** Describes how output files should be placed in the filesystem. */ + case class OutputSpec( + outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) + /** A shared job description for all the write tasks. */ private class WriteJobDescription( val uuid: String, // prevent collision between different (appending) write jobs @@ -56,7 +60,8 @@ object FileFormatWriter extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val path: String) + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String]) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -83,7 +88,7 @@ object FileFormatWriter extends Logging { plan: LogicalPlan, fileFormat: FileFormat, committer: FileCommitProtocol, - outputPath: String, + outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -93,7 +98,7 @@ object FileFormatWriter extends Logging { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, new Path(outputPath)) + FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,7 +116,8 @@ object FileFormatWriter extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - path = outputPath) + path = outputSpec.outputPath, + customPartitionLocations = outputSpec.customPartitionLocations) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and @@ -308,7 +314,17 @@ object FileFormatWriter extends Logging { } val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val path = committer.newTaskTempFile(taskAttemptContext, 
partDir, ext) + val customPath = partDir match { + case Some(dir) => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + case _ => + None + } + val path = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } val newWriter = description.outputWriterFactory.newInstance( path = path, dataSchema = description.nonPartitionColumns.toStructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a0a8cb5024c3..28975e1546e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ @@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. * Writing to dynamic partitions is also supported. + * + * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of + * partition overwrites: when the spec is empty, all partitions are + * overwritten. When it covers a prefix of the partition keys, only + * partitions matching the prefix are overwritten. + * @param customPartitionLocations mapping of partition specs to their custom locations. The + * caller should guarantee that exactly those table partitions + * falling under the specified static partition keys are contained + * in this map, and that no other partitions are. 
*/ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, + staticPartitionKeys: TablePartitionSpec, + customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: Seq[TablePartitionSpec] => Unit, options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable]) extends RunnableCommand { + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { @@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } + deleteMatchingPartitions(fs, qualifiedOutputPath) true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true @@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand( plan = query, fileFormat = fileFormat, committer = committer, - outputPath = qualifiedOutputPath.toString, + outputSpec = FileFormatWriter.OutputSpec( + qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + + /** + * Deletes all partition files that match the specified static prefix. Partitions with custom + * locations are also cleared based on the custom locations map given to this class. + */ + private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { + val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) { + "/" + partitionColumns.flatMap { p => + staticPartitionKeys.get(p.name) match { + case Some(value) => + Some(escapePathName(p.name) + "=" + escapePathName(value)) + case None => + None + } + }.mkString("/") + } else { + "" + } + // first clear the path determined by the static partition keys (e.g. /table/foo=1) + val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix) + if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $staticPrefixPath prior to writing to it") + } + // now clear all custom partition locations (e.g. 
/custom/dir/where/foo=2/bar=4) + for ((spec, customLoc) <- customPartitionLocations) { + assert( + (staticPartitionKeys.toSet -- spec).isEmpty, + "Custom partition location did not match static partitioning keys") + val path = new Path(customLoc) + if (fs.exists(path) && !fs.delete(path, true)) { + throw new IOException(s"Unable to clear partition " + + s"directory $path prior to writing to it") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index a28b04ca3fb5..bf9f318780ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -62,6 +62,7 @@ object PartitioningUtils { } import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName /** @@ -252,6 +253,15 @@ object PartitioningUtils { }.toMap } + /** + * This is the inverse of parsePathFragment(). + */ + def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = { + partitionSchema.map { field => + escapePathName(field.name) + "=" + escapePathName(spec(field.name)) + }.mkString("/") + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e849cafef418..f1c5f9ab5067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -80,7 +80,7 @@ class FileStreamSink( plan = data.logicalPlan, fileFormat = fileFormat, committer = committer, - outputPath = path, + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 1fe13fa1623f..92191c8b64b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String) file } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + throw new UnsupportedOperationException( + s"$this does not support adding files with an absolute path") + } + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { if (addedFiles.nonEmpty) { val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 
ac435bf6195b..a1aa07456fd3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class PartitionProviderCompatibilitySuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite } } - test("insert overwrite partition of legacy datasource table overwrites entire table") { + test("insert overwrite partition of legacy datasource table") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { withTempDir { dir => @@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite """insert overwrite table test |partition (partCol=1) |select * from range(100)""".stripMargin) - assert(spark.sql("select * from test").count() == 100) + assert(spark.sql("select * from test").count() == 104) - // Dynamic partitions case + // Overwriting entire table spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) assert(spark.sql("select * from test").count() == 10) } @@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite } } } + + /** + * Runs a test against a multi-level partitioned table, then validates that the custom locations + * were respected by the output writer. + * + * The initial partitioning structure is: + * /P1=0/P2=0 -- custom location a + * /P1=0/P2=1 -- custom location b + * /P1=1/P2=0 -- custom location c + * /P1=1/P2=1 -- default location + */ + private def testCustomLocations(testFn: => Unit): Unit = { + val base = Utils.createTempDir(namePrefix = "base") + val a = Utils.createTempDir(namePrefix = "a") + val b = Utils.createTempDir(namePrefix = "b") + val c = Utils.createTempDir(namePrefix = "c") + try { + spark.sql(s""" + |create table test (id long, P1 int, P2 int) + |using parquet + |options (path "${base.getAbsolutePath}") + |partitioned by (P1, P2)""".stripMargin) + spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=1)") + + testFn + + // Now validate the partition custom locations were respected + val initialCount = spark.sql("select * from test").count() + val numA = spark.sql("select * from test where P1=0 and P2=0").count() + val numB = spark.sql("select * from test where P1=0 and P2=1").count() + val numC = spark.sql("select * from test where P1=1 and P2=0").count() + Utils.deleteRecursively(a) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA) + Utils.deleteRecursively(b) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA - numB) + Utils.deleteRecursively(c) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() 
== initialCount - numA - numB - numC) + } finally { + Utils.deleteRecursively(base) + Utils.deleteRecursively(a) + Utils.deleteRecursively(b) + Utils.deleteRecursively(c) + spark.sql("drop table test") + } + } + + test("sanity check table setup") { + testCustomLocations { + assert(spark.sql("select * from test").count() == 0) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("insert into partial dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 20) + spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 40) + assert(spark.sql("show partitions test").count() == 30) + } + } + + test("insert into fully dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + } + } + + test("insert into static partition") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("overwrite partial dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 7) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 1) + assert(spark.sql("show partitions test").count() == 3) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 11) + assert(spark.sql("show partitions test").count() == 11) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 2) + assert(spark.sql("show partitions test").count() == 2) + spark.sql("insert overwrite table test partition (P1=3, 
P2) select id, id from range(100)") + assert(spark.sql("select * from test").count() == 102) + assert(spark.sql("show partitions test").count() == 102) + } + } + + test("overwrite fully dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 10) + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 5) + } + } + + test("overwrite static partition") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)") + assert(spark.sql("select * from test").count() == 15) + assert(spark.sql("show partitions test").count() == 5) + } + } } From 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Nov 2016 17:13:10 -0800 Subject: [PATCH 189/381] [SPARK-18401][SPARKR][ML] SparkR random forest should support output original label. ## What changes were proposed in this pull request? SparkR ```spark.randomForest``` classification prediction should output original label rather than the indexed label. This issue is very similar with [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #15842 from yanboliang/spark-18401. 
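The fix follows the same pattern as SPARK-18291: the wrapper forces the string response variable to be indexed for training, keeps the classifier's numeric prediction in an intermediate column, and appends an `IndexToString` stage that maps the predicted index back to the original label before results are returned to R. Below is a minimal Scala sketch of that pattern, using a `StringIndexer` where the wrapper uses `RFormula`; the column names are illustrative and the input DataFrame is assumed to have `label` and `features` columns.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.DataFrame

def fitWithOriginalLabels(train: DataFrame): PipelineModel = {
  // Index the string label column for training.
  val labelIndexer = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("indexedLabel")
    .fit(train)

  // Predict into an intermediate index column rather than the user-facing one.
  val rf = new RandomForestClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("features")
    .setPredictionCol("predictionIndex")

  // Map predicted indices back to the original label strings.
  val labelConverter = new IndexToString()
    .setInputCol("predictionIndex")
    .setOutputCol("prediction")
    .setLabels(labelIndexer.labels)

  new Pipeline().setStages(Array(labelIndexer, rf, labelConverter)).fit(train)
}
```

The wrapper in the diff below does the same thing with `RFormula.setForceIndexLabel(true)` and drops the intermediate index column before handing predictions back to R.
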
--- R/pkg/inst/tests/testthat/test_mllib.R | 24 ++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 28 ++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 33e9d0d267ac..b76f75dbdc68 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) }) test_that("spark.gbt", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 6947ba7e7597..31f846dc6cfe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { + import RandomForestClassifierWrapper._ + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] @@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private ( def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(rfcModel.getFeaturesCol) } override def write: MLWriter = new @@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private ( } 
private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( // scalastyle:ignore data: DataFrame, formula: String, @@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .attributes.get val features = featureAttrs.map(_.name.get) + // get label names from output schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + // assemble and fit the pipeline val rfc = new RandomForestClassifier() .setMaxDepth(maxDepth) @@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, rfc)) + .setStages(Array(rFormulaModel, rfc, idxToStr)) .fit(data) new RandomForestClassifierWrapper(pipeline, formula, features) From 4f15d94cfec86130f8dab28ae2e228ded8124020 Mon Sep 17 00:00:00 2001 From: Junjie Chen Date: Fri, 11 Nov 2016 10:37:58 -0800 Subject: [PATCH 190/381] [SPARK-13331] AES support for over-the-wire encryption ## What changes were proposed in this pull request? DIGEST-MD5 mechanism is used for SASL authentication and secure communication. DIGEST-MD5 mechanism supports 3DES, DES, and RC4 ciphers. However, 3DES, DES and RC4 are slow relatively. AES provide better performance and security by design and is a replacement for 3DES according to NIST. Apache Common Crypto is a cryptographic library optimized with AES-NI, this patch employ Apache Common Crypto as enc/dec backend for SASL authentication and secure channel to improve spark RPC. ## How was this patch tested? Unit tests and Integration test. Author: Junjie Chen Closes #15172 from cjjnjust/shuffle_rpc_encrypt. 
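To illustrate the Commons Crypto primitives this patch builds on, here is a minimal Scala sketch of an AES/CTR round trip over in-memory streams. It is not part of the patch: the object name `AesCtrDemo` and the helper `demoRoundTrip` are invented for the example, and the real code attaches the crypto streams to Netty channel handlers rather than byte-array streams.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.Properties
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random.CryptoRandomFactory
import org.apache.commons.crypto.stream.{CryptoInputStream, CryptoOutputStream}

// Hypothetical demo object, not part of Spark.
object AesCtrDemo {
  private val transform = "AES/CTR/NoPadding"

  def demoRoundTrip(payload: Array[Byte]): Array[Byte] = {
    val props = new Properties()
    // Generate a random key and IV, much like the negotiation message does per direction.
    val random = CryptoRandomFactory.getCryptoRandom(props)
    val key = new Array[Byte](16) // 16, 24 or 32 bytes, matching the new keySize setting
    val iv = new Array[Byte](16)  // CTR mode uses a block-sized IV
    random.nextBytes(key)
    random.nextBytes(iv)
    val keySpec = new SecretKeySpec(key, "AES")
    val ivSpec = new IvParameterSpec(iv)

    // Encrypt: bytes written to the crypto stream come out of `encrypted` as ciphertext.
    val encrypted = new ByteArrayOutputStream()
    val cos = new CryptoOutputStream(transform, props, encrypted, keySpec, ivSpec)
    cos.write(payload)
    cos.close()

    // Decrypt on the "receiving" side with the same key and IV.
    val cis = new CryptoInputStream(
      transform, props, new ByteArrayInputStream(encrypted.toByteArray), keySpec, ivSpec)
    val decrypted = new Array[Byte](payload.length)
    var off = 0
    while (off < decrypted.length) {
      off += cis.read(decrypted, off, decrypted.length - off)
    }
    cis.close()
    decrypted
  }
}
```

In the patch itself the same CryptoInputStream/CryptoOutputStream pair is wired into the channel pipeline by `AesCipher.addToChannel`, after the key and IV material has been exchanged over the SASL-encrypted connection.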
--- common/network-common/pom.xml | 4 + .../network/sasl/SaslClientBootstrap.java | 23 +- .../spark/network/sasl/SaslRpcHandler.java | 101 ++++-- .../spark/network/sasl/aes/AesCipher.java | 294 ++++++++++++++++++ .../network/sasl/aes/AesConfigMessage.java | 101 ++++++ .../util/ByteArrayReadableChannel.java | 62 ++++ .../spark/network/util/TransportConf.java | 22 ++ .../spark/network/sasl/SparkSaslSuite.java | 93 +++++- docs/configuration.md | 26 ++ 9 files changed, 689 insertions(+), 37 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c9..ca99fa89ebe1 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -76,6 +76,10 @@ guava compile + + org.apache.commons + commons-crypto + diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9e5c616ee5a1..a1bb45365746 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,6 +30,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -88,9 +90,26 @@ public void doBootstrap(TransportClient client, Channel channel) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + + if (conf.aesEncryptionEnabled()) { + // Generate a request config message to send to server. + AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); + ByteBuffer buf = configMessage.encodeMessage(); + + // Encrypted the config message. 
+ byte[] toEncrypt = JavaUtils.bufferToArray(buf); + ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); + + client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); + AesCipher cipher = new AesCipher(configMessage, conf); + logger.info("Enabling AES cipher for client channel {}", client); + cipher.addToChannel(channel); + saslClient.dispose(); + } else { + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + } saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); + logger.debug("Channel {} configured for encryption.", client); } } catch (IOException ioe) { throw new RuntimeException(ioe); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6..b2f3ef214b7a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,6 +29,8 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler { private SparkSaslServer saslServer; private boolean isComplete; + private boolean isAuthenticated; SaslRpcHandler( TransportConf conf, @@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler { this.secretKeyHolder = secretKeyHolder; this.saslServer = null; this.isComplete = false; + this.isAuthenticated = false; } @Override @@ -80,30 +84,31 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb delegate.receive(client, message, callback); return; } + if (saslServer == null || !saslServer.isComplete()) { + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); } - callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. 
It's ok to change the channel pipeline here since we are processing an incoming @@ -111,15 +116,42 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("SASL authentication successful for channel {}", client); + complete(true); + return; + } + + if (!conf.aesEncryptionEnabled()) { logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; + complete(false); + return; + } + + // Extra negotiation should happen after authentication, so return directly while + // processing authenticate. + if (!isAuthenticated) { + logger.debug("SASL authentication successful for channel {}", client); + isAuthenticated = true; + return; + } + + // Create AES cipher when it is authenticated + try { + byte[] encrypted = JavaUtils.bufferToArray(message); + ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); + + AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); + AesCipher cipher = new AesCipher(configMessage, conf); + + // Send response back to client to confirm that server accept config. + callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); + logger.info("Enabling AES cipher for Server channel {}", client); + cipher.addToChannel(channel); + complete(true); + } catch (IOException ioe) { + throw new RuntimeException(ioe); } } } @@ -155,4 +187,17 @@ public void exceptionCaught(Throwable cause, TransportClient client) { delegate.exceptionCaught(cause, client); } + private void complete(boolean dispose) { + if (dispose) { + try { + saslServer.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL server", e); + } + } + + saslServer = null; + isComplete = true; + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java new file mode 100644 index 000000000000..78034a69f734 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.AbstractReferenceCounted; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * AES cipher for encryption and decryption. + */ +public class AesCipher { + private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); + public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; + public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; + public static final int STREAM_BUFFER_SIZE = 1024 * 32; + public static final String TRANSFORM = "AES/CTR/NoPadding"; + + private final SecretKeySpec inKeySpec; + private final IvParameterSpec inIvSpec; + private final SecretKeySpec outKeySpec; + private final IvParameterSpec outIvSpec; + private final Properties properties; + + public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { + this.properties = CryptoStreamUtils.toCryptoConf(conf); + this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); + this.inIvSpec = new IvParameterSpec(configMessage.inIv); + this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); + this.outIvSpec = new IvParameterSpec(configMessage.outIv); + } + + /** + * Create AES crypto output stream + * @param ch The underlying channel to write out. + * @return Return output crypto stream for encryption. + * @throws IOException + */ + private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + } + + /** + * Create AES crypto input stream + * @param ch The underlying channel used to read data. + * @return Return input crypto stream for decryption. + * @throws IOException + */ + private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + } + + /** + * Add handlers to channel + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); + } + + /** + * Create the configuration message + * @param conf is the local transport configuration. + * @return Config message for sending. 
+ */ + public static AesConfigMessage createConfigMessage(TransportConf conf) { + int keySize = conf.aesCipherKeySize(); + Properties properties = CryptoStreamUtils.toCryptoConf(conf); + + try { + int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) + .getBlockSize(); + byte[] inKey = new byte[keySize]; + byte[] outKey = new byte[keySize]; + byte[] inIv = new byte[paramLen]; + byte[] outIv = new byte[paramLen]; + + CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); + random.nextBytes(inKey); + random.nextBytes(outKey); + random.nextBytes(inIv); + random.nextBytes(outIv); + + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } catch (Exception e) { + logger.error("AES config error", e); + throw Throwables.propagate(e); + } + } + + /** + * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config. + */ + private static class CryptoStreamUtils { + public static Properties toCryptoConf(TransportConf conf) { + Properties props = new Properties(); + if (conf.aesCipherClass() != null) { + props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass()); + } + return props; + } + } + + private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteChannel; + private final CryptoOutputStream cos; + + AesEncryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteChannel); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + cos.close(); + } finally { + super.close(ctx, promise); + } + } + } + + private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + + AesDecryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + byteChannel.feedData((ByteBuf) data); + + byte[] decryptedData = new byte[byteChannel.readableBytes()]; + int offset = 0; + while (offset < decryptedData.length) { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + cis.close(); + } finally { + super.channelInactive(ctx); + } + } + } + + private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private long transferred; + private CryptoOutputStream cos; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. 
+ private ByteArrayWritableChannel byteEncChannel; + private ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.cos = cos; + this.byteEncChannel = ch; + } + + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return transferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + do { + if (currentEncrypted == null) { + encryptMore(); + } + + int bytesWritten = currentEncrypted.remaining(); + target.write(currentEncrypted); + bytesWritten -= currentEncrypted.remaining(); + transferred += bytesWritten; + if (!currentEncrypted.hasRemaining()) { + currentEncrypted = null; + byteEncChannel.reset(); + } + } while (transferred < count()); + + return transferred; + } + + private void encryptMore() throws IOException { + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transfered()); + } + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java new file mode 100644 index 000000000000..3ef6f74a1f89 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The AES cipher options for encryption negotiation. + */ +public class AesConfigMessage implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEB; + + public byte[] inKey; + public byte[] outKey; + public byte[] inIv; + public byte[] outIv; + + public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { + if (inKey == null || inIv == null || outKey == null || outIv == null) { + throw new IllegalArgumentException("Cipher Key or IV must not be null!"); + } + + this.inKey = inKey; + this.inIv = inIv; + this.outKey = outKey; + this.outIv = outIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + + Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, inKey); + Encoders.ByteArrays.encode(buf, inIv); + Encoders.ByteArrays.encode(buf, outKey); + Encoders.ByteArrays.encode(buf, outIv); + } + + /** + * Encode the config message. + * @return ByteBuffer which contains encoded config message. + */ + public ByteBuffer encodeMessage(){ + ByteBuffer buf = ByteBuffer.allocate(encodedLength()); + + ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); + wrappedBuf.clear(); + encode(wrappedBuf); + + return buf; + } + + /** + * Decode the config message from buffer + * @param buffer the buffer contain encoded config message + * @return config message + */ + public static AesConfigMessage decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected AesConfigMessage, received something else" + + " (maybe your client does not have AES enabled?)"); + } + + byte[] outKey = Encoders.ByteArrays.decode(buf); + byte[] outIv = Encoders.ByteArrays.decode(buf); + byte[] inKey = Encoders.ByteArrays.decode(buf); + byte[] inIv = Encoders.ByteArrays.decode(buf); + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java new file mode 100644 index 000000000000..25d103d0e316 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import io.netty.buffer.ByteBuf; + +public class ByteArrayReadableChannel implements ReadableByteChannel { + private ByteBuf data; + + public int readableBytes() { + return data.readableBytes(); + } + + public void feedData(ByteBuf buf) { + data = buf; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int totalRead = 0; + while (data.readableBytes() > 0 && dst.remaining() > 0) { + int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); + dst.put(data.readSlice(bytesToRead).nioBuffer()); + totalRead += bytesToRead; + } + + if (data.readableBytes() == 0) { + data.release(); + } + + return totalRead; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103ccc..d0d072849d38 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -18,6 +18,7 @@ package org.apache.spark.network.util; import com.google.common.primitives.Ints; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; /** * A central location that tracks all the settings we expose to users. @@ -175,4 +176,25 @@ public boolean saslServerAlwaysEncrypt() { return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } + /** + * The trigger for enabling AES encryption. + */ + public boolean aesEncryptionEnabled() { + return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false); + } + + /** + * The implementation class for crypto cipher + */ + public String aesCipherClass() { + return conf.get("spark.authenticate.encryption.aes.cipher.class", null); + } + + /** + * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that + * the length should be 16, 24 or 32 bytes. 
+ */ + public int aesCipherKeySize() { + return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 45cc03df435a..4e6146cf070d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -53,6 +53,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -149,7 +150,7 @@ public Void answer(InvocationOnMock invocation) { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -275,7 +276,7 @@ public ManagedBuffer answer(InvocationOnMock invocation) { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false); + ctx = new SaslTestCtx(rpcHandler, true, false, false); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +318,7 @@ public void testServerAlwaysEncrypt() throws Exception { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -336,7 +337,7 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { // able to understand RPCs sent to it and thus close the connection. 
SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,6 +375,69 @@ public void testDelegates() throws Exception { } } + @Test + public void testAesEncryption() throws Exception { + final AtomicReference response = new AtomicReference<>(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider()); + final TransportConf spyConf = spy(conf); + doReturn(true).when(spyConf).aesEncryptionEnabled(); + + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[256 * 1024 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false, true); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + } + } + private static class SaslTestCtx { final TransportClient client; @@ -386,18 +450,28 @@ private static class SaslTestCtx { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption) + boolean disableClientEncryption, + boolean aesEnable) throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + if (aesEnable) { + conf = spy(conf); + doReturn(true).when(conf).aesEncryptionEnabled(); + } + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - this.checker = new EncryptionCheckerBootstrap(); + String encryptHandlerName = aesEnable ? 
AesCipher.ENCRYPTION_HANDLER_NAME : + SaslEncryption.ENCRYPTION_HANDLER_NAME; + + this.checker = new EncryptionCheckerBootstrap(encryptHandlerName); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); @@ -437,13 +511,18 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd implements TransportServerBootstrap { boolean foundEncryptionHandler; + String encryptHandlerName; + + public EncryptionCheckerBootstrap(String encryptHandlerName) { + this.encryptHandlerName = encryptHandlerName; + } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (!foundEncryptionHandler) { foundEncryptionHandler = - ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + ctx.channel().pipeline().get(encryptHandlerName) != null; } ctx.write(msg, promise); } diff --git a/docs/configuration.md b/docs/configuration.md index d0acd944dd6b..41c1778ee7fc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1529,6 +1529,32 @@ Apart from these, the following properties are also available, and may be useful currently supported by the external shuffle service. + + spark.authenticate.encryption.aes.enabled + false + + Enable AES for over-the-wire encryption + + + + spark.authenticate.encryption.aes.cipher.keySize + 16 + + The bytes of AES cipher key which is effective when AES cipher is enabled. AES + works with 16, 24 and 32 bytes keys. + + + + spark.authenticate.encryption.aes.cipher.class + null + + Specify the underlying implementation class of crypto cipher. Set null here to use default. + In order to use OpenSslCipher users should install openssl. Currently, there are two cipher + classes available in Commons Crypto library: + org.apache.commons.crypto.cipher.OpenSslCipher + org.apache.commons.crypto.cipher.JceCipher + + spark.core.connection.ack.wait.timeout 60s From a531fe1a82ec515314f2db2e2305283fef24067f Mon Sep 17 00:00:00 2001 From: Vinayak Date: Fri, 11 Nov 2016 12:54:16 -0600 Subject: [PATCH 191/381] [SPARK-17843][WEB UI] Indicate event logs pending for processing on history server UI ## What changes were proposed in this pull request? History Server UI's application listing to display information on currently under process event logs so a user knows that pending this processing an application may not list on the UI. When there are no event logs under process, the application list page has a "Last Updated" date-time at the top indicating the date-time of the last _completed_ scan of the event logs. The value is displayed to the user in his/her local time zone. ## How was this patch tested? All unit tests pass. Particularly all the suites under org.apache.spark.deploy.history.\* were run to test changes. - Very first startup - Pending logs - no logs processed yet: screen shot 2016-10-24 at 3 07 04 pm - Very first startup - Pending logs - some logs processed: screen shot 2016-10-24 at 3 18 42 pm - Last updated - No currently pending logs: screen shot 2016-10-17 at 8 34 37 pm - Last updated - With some currently pending logs: screen shot 2016-10-24 at 3 09 31 pm - No applications found and No currently pending logs: screen shot 2016-10-24 at 3 24 26 pm Author: Vinayak Closes #15410 from vijoshi/SAAS-608_master. 
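The core bookkeeping in the provider is a pending-task counter wrapped around the replay thread pool, plus an atomically updated scan timestamp. The sketch below shows that pattern in isolation; `PendingWorkTracker`, `process` and `runBatch` are names invented for the illustration and are not Spark APIs.

```scala
import java.util.concurrent.{Executors, Future}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

import scala.collection.mutable

// Hypothetical helper showing the counting pattern used by the history provider.
class PendingWorkTracker[T](process: T => Unit) {
  private val pool = Executors.newFixedThreadPool(4)
  private val pendingCount = new AtomicInteger(0)
  private val lastUpdated = new AtomicLong(-1L)

  def getPendingCount: Int = pendingCount.get()
  def getLastUpdatedTime: Long = lastUpdated.get()

  def runBatch(items: Seq[T]): Unit = {
    val tasks = mutable.ListBuffer[Future[_]]()
    try {
      for (item <- items) {
        tasks += pool.submit(new Runnable {
          override def run(): Unit = process(item)
        })
      }
    } catch {
      // Stop submitting if the executor rejects work, but still wait for accepted tasks.
      case e: Exception => println(s"submission failed: $e")
    }

    pendingCount.addAndGet(tasks.size)
    tasks.foreach { task =>
      try {
        task.get()                       // wait so the next scan does not overlap this one
      } finally {
        pendingCount.decrementAndGet()   // a finished task is no longer "pending"
      }
    }
    lastUpdated.set(System.currentTimeMillis())
  }

  def shutdown(): Unit = pool.shutdown()
}
```

The history page then only needs the two getters: it renders the "event logs being processed" hint when the pending count is non-zero and otherwise shows the last completed scan time, which the new `historypage-common.js` converts to the browser's local time zone.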
--- .../spark/ui/static/historypage-common.js | 24 ++++++++ .../history/ApplicationHistoryProvider.scala | 24 ++++++++ .../deploy/history/FsHistoryProvider.scala | 59 +++++++++++++------ .../spark/deploy/history/HistoryPage.scala | 19 ++++++ .../spark/deploy/history/HistoryServer.scala | 8 +++ 5 files changed, 116 insertions(+), 18 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-common.js diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 000000000000..55d540d8317a --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff83646..d7d82800b8b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + return 0; + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + return 0; + } + /** * Returns a list of applications available for the history server to show. 
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c81..ca38a4763942 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +120,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +228,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -329,26 +335,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +388,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14..0e7a6c24d4fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,13 +30,30 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
              {providerConfig.map { case (k, v) =>
                {k}: {v}
              }}
+            {
+              if (eventLogsUnderProcessCount > 0) {
+                There are {eventLogsUnderProcessCount} event log(s) currently being
+                processed which may result in additional applications getting listed on this page.
+                Refresh the page to view updates.
+              }
+            }
+
+            {
+              if (lastUpdatedTime > 0) {
+                Last updated: {lastUpdatedTime}
+              }
+            }
+
             {
               if (allAppsSize > 0) {
                 ++
@@ -46,6 +63,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
       } else if (requestedIncomplete) {
         No incomplete applications found!
+      } else if (eventLogsUnderProcessCount > 0) {
+        No completed applications found!
       } else {
         No completed applications found!
++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56..7e21fa681aa1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } From d42bb7cc4e32c173769bd7da5b9b5eafb510860c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 11 Nov 2016 13:28:18 -0800 Subject: [PATCH 192/381] [SPARK-17982][SQL] SQLBuilder should wrap the generated SQL with parenthesis for LIMIT ## What changes were proposed in this pull request? Currently, `SQLBuilder` handles `LIMIT` by always adding `LIMIT` at the end of the generated subSQL. It makes `RuntimeException`s like the following. This PR adds a parenthesis always except `SubqueryAlias` is used together with `LIMIT`. **Before** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") java.lang.RuntimeException: Failed to analyze the canonicalized SQL: ... ``` **After** ``` scala scala> sql("CREATE TABLE tbl(id INT)") scala> sql("CREATE VIEW v1(id2) AS SELECT id FROM tbl LIMIT 2") scala> sql("SELECT id2 FROM v1") res4: org.apache.spark.sql.DataFrame = [id2: int] ``` **Fixed cases in this PR** The following two cases are the detail query plans having problematic SQL generations. 1. `SELECT * FROM (SELECT id FROM tbl LIMIT 2)` Please note that **FROM SELECT** part of the generated SQL in the below. When we don't use '()' for limit, this fails. ```scala # Original logical plan: Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [id#1] +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- Project [gen_attr_0#1] +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2) AS tbl ``` 2. `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` Please note that **((~~~) AS gen_subquery_0 LIMIT 2)** in the below. When we use '()' for limit on `SubqueryAlias`, this fails. ```scala # Original logical plan: Project [id#1] +- Project [id#1] +- GlobalLimit 2 +- LocalLimit 2 +- MetastoreRelation default, tbl # Canonicalized logical plan: Project [gen_attr_0#1 AS id#4] +- SubqueryAlias tbl +- Project [gen_attr_0#1] +- GlobalLimit 2 +- LocalLimit 2 +- SubqueryAlias gen_subquery_0 +- Project [id#1 AS gen_attr_0#1] +- SQLTable default, tbl, [id#1] # Generated SQL: SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM ((SELECT `id` AS `gen_attr_0` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl ``` ## How was this patch tested? Pass the Jenkins test with a newly added test case. Author: Dongjoon Hyun Closes #15546 from dongjoon-hyun/SPARK-17982. 
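Reduced to its essentials, the fix is a two-case rule in the SQL generation for `Limit`: keep the old form when the child is a `SubqueryAlias` (the TABLESAMPLE shape), and wrap everything else in parentheses so the result re-parses. The sketch below is a simplification, with an invented `LimitSqlRule` object and plain strings standing in for the real plan nodes.

```scala
object LimitSqlRule {
  // childIsSubqueryAlias marks the TABLESAMPLE shape, where adding parentheses
  // would itself produce SQL that fails to parse.
  def limitToSQL(childSQL: String, limitSQL: String, childIsSubqueryAlias: Boolean): String = {
    if (childIsSubqueryAlias) {
      s"$childSQL LIMIT $limitSQL"    // e.g. (...) AS gen_subquery_0 LIMIT 2
    } else {
      s"($childSQL LIMIT $limitSQL)"  // e.g. (SELECT `gen_attr_0` FROM ... LIMIT 2) AS tbl
    }
  }
}
```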
--- .../org/apache/spark/sql/catalyst/SQLBuilder.scala | 7 ++++++- .../test/resources/sqlgen/generate_with_other_1.sql | 2 +- .../test/resources/sqlgen/generate_with_other_2.sql | 2 +- sql/hive/src/test/resources/sqlgen/limit.sql | 4 ++++ .../spark/sql/catalyst/LogicalPlanToSQLSuite.scala | 10 ++++++++++ 5 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 sql/hive/src/test/resources/sqlgen/limit.sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 6f821f80cc4c..380454267eaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -138,9 +138,14 @@ class SQLBuilder private ( case g: Generate => generateToSQL(g) - case Limit(limitExpr, child) => + // This prevents a pattern of `((...) AS gen_subquery_0 LIMIT 1)` which does not work. + // For example, `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` makes this plan. + case Limit(limitExpr, child: SubqueryAlias) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" + case Limit(limitExpr, child) => + s"(${toSQL(child)} LIMIT ${limitExpr.sql})" + case Filter(condition, child) => val whereOrHaving = child match { case _: Aggregate => "HAVING" diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql index ab444d0c7093..0739f8fff546 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql @@ -5,4 +5,4 @@ WHERE id > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS parquet_t3 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql index 42a2369f34d1..c4b344ee238a 100644 --- a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql @@ -7,4 +7,4 @@ WHERE val > 2 ORDER BY val, id LIMIT 5 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS 
`gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5) AS gen_subquery_1 +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM ((SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST LIMIT 5)) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/limit.sql b/sql/hive/src/test/resources/sqlgen/limit.sql new file mode 100644 index 000000000000..7a6b060fbf50 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/limit.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM (SELECT id FROM tbl LIMIT 2) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0`, `name` AS `gen_attr_1` FROM `default`.`tbl`) AS gen_subquery_0 LIMIT 2)) AS tbl diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 8696337b9dc8..557ea44d1c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1173,4 +1173,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { ) } } + + test("SPARK-17982 - limit") { + withTable("tbl") { + sql("CREATE TABLE tbl(id INT, name STRING)") + checkSQL( + "SELECT * FROM (SELECT id FROM tbl LIMIT 2)", + "limit" + ) + } + } } From 6e95325fc3726d260054bd6e7c0717b3c139917e Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 11 Nov 2016 13:52:10 -0800 Subject: [PATCH 193/381] [SPARK-18387][SQL] Add serialization to checkEvaluation. ## What changes were proposed in this pull request? This removes the serialization test from RegexpExpressionsSuite and replaces it by serializing all expressions in checkEvaluation. This also fixes math constant expressions by making LeafMathExpression Serializable and fixes NumberFormat values that are null or invalid after serialization. ## How was this patch tested? This patch is to tests. Author: Ryan Blue Closes #15847 from rdblue/SPARK-18387-fix-serializable-expressions. 
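The heart of the change is that every expression handed to `checkEvaluation` now goes through a Java-serialization round trip before being evaluated, so any non-serializable expression or stale transient state fails the existing tests. Below is a standalone sketch of that round trip, assuming spark-core on the classpath; the `roundTrip` helper is illustrative, not part of the Spark test API.

```scala
import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializationRoundTrip {
  // Serialize and immediately deserialize a value, returning the copy.
  // Evaluating the copy instead of the original is what surfaces bugs like
  // FormatNumber's transient DecimalFormat not being rebuilt after deserialization.
  def roundTrip[T <: AnyRef : ClassTag](value: T): T = {
    val serializer = new JavaSerializer(new SparkConf()).newInstance()
    serializer.deserialize[T](serializer.serialize(value))
  }
}
```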
--- .../expressions/mathExpressions.scala | 2 +- .../expressions/stringExpressions.scala | 44 +++++++++++-------- .../expressions/ExpressionEvalHelper.scala | 15 ++++--- .../expressions/RegexpExpressionsSuite.scala | 16 +------ 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a60494a5bb69..65273a77b105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with CodegenFallback { + extends LeafExpression with CodegenFallback with Serializable { override def dataType: DataType = DoubleType override def foldable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5f533fecf8d0..e74ef9a08750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1431,18 +1431,20 @@ case class FormatNumber(x: Expression, d: Expression) // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. + // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after + // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Int = -100 + private var lastDValue: Option[Int] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @transient - private val pattern: StringBuffer = new StringBuffer() + private lazy val pattern: StringBuffer = new StringBuffer() // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
@transient - private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) + private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -1450,24 +1452,28 @@ case class FormatNumber(x: Expression, d: Expression) return null } - if (dValue != lastDValue) { - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") + lastDValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } } - } - lastDValue = dValue - numberFormat.applyLocalizedPattern(pattern.toString) + lastDValue = Some(dValue) + + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 9ceb70918541..f83650424a96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -22,7 +22,8 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.GeneratorDrivenPropertyChecks -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -43,13 +44,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val expr: Expression = serializer.deserialize(serializer.serialize(expression)) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - if (GenerateUnsafeProjection.canSupport(expression.dataType)) { - checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) } - checkEvaluationWithOptimization(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d0d1aaa9d299..5299549e7b4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.StringType @@ -192,17 +191,4 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSplit(s1, s2), null, row3) } - test("RegExpReplace serialization") { - val serializer = new JavaSerializer(new SparkConf()).newInstance - - val row = create_row("abc", "b", "") - - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.string.at(2) - - val expr: RegExpReplace = serializer.deserialize(serializer.serialize(RegExpReplace(s, p, r))) - checkEvaluation(expr, "ac", row) - } - } From ba23f768f7419039df85530b84258ec31f0c22b4 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 11 Nov 2016 15:49:55 -0800 Subject: [PATCH 194/381] [SPARK-18264][SPARKR] build vignettes with package, update vignettes for CRAN release build and add info on release ## What changes were proposed in this pull request? Changes to DESCRIPTION to build vignettes. Changes the metadata for vignettes to generate the recommended format (which is about <10% of size before). Unfortunately it does not look as nice (before - left, after - right) ![image](https://cloud.githubusercontent.com/assets/8969467/20040492/b75883e6-a40d-11e6-9534-25cdd5d59a8b.png) ![image](https://cloud.githubusercontent.com/assets/8969467/20040490/a40f4d42-a40d-11e6-8c91-af00ddcbdad9.png) Also add information on how to run build/release to CRAN later. ## How was this patch tested? manually, unit tests shivaram We need this for branch-2.1 Author: Felix Cheung Closes #15790 from felixcheung/rpkgvignettes. --- R/CRAN_RELEASE.md | 91 ++++++++++++++++++++++++++++ R/README.md | 8 +-- R/check-cran.sh | 33 ++++++++-- R/create-docs.sh | 19 +----- R/pkg/DESCRIPTION | 9 ++- R/pkg/vignettes/sparkr-vignettes.Rmd | 9 +-- 6 files changed, 134 insertions(+), 35 deletions(-) create mode 100644 R/CRAN_RELEASE.md diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 000000000000..bea8f9fbe4ee --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we would use the `devtools` package. Please work with the +`dev@spark.apache.org` community and R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `check-cran.sh` is running `R CMD check`, it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it will be preferred to run `R CMD check` on the source package built manually before uploading a release. + +To upload a release, we would need to update the `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script along with comments on status of all `WARNING` (should not be any) or `NOTE`. 
As a part of `check-cran.sh` and the release process, the vignettes is build - make sure `SPARK_HOME` is set and Spark jars are accessible. + +Once everything is in place, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths) +``` + +For more information please refer to http://r-pkgs.had.co.nz/release.html#release-check + +### Testing: build package manually + +To build package manually such as to inspect the resulting `.tar.gz` file content, we would also use the `devtools` package. + +Source package is what get released to CRAN. CRAN would then build platform-specific binary packages from the source package. + +#### Build source package + +To build source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths) +``` + +(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2) + +Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`. + +For example, this should be the content of the source package: + +```sh +DESCRIPTION R inst tests +NAMESPACE build man vignettes + +inst/doc/ +sparkr-vignettes.html +sparkr-vignettes.Rmd +sparkr-vignettes.Rman + +build/ +vignette.rds + +man/ + *.Rd files... + +vignettes/ +sparkr-vignettes.Rmd +``` + +#### Test source package + +To install, run this: + +```sh +R CMD INSTALL SparkR_2.1.0.tar.gz +``` + +With "2.1.0" replaced with the version of SparkR. + +This command installs SparkR to the default libPaths. Once that is done, you should be able to start R and run: + +```R +library(SparkR) +vignette("sparkr-vignettes", package="SparkR") +``` + +#### Build binary package + +To build binary package locally, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths) +``` + +For example, this should be the content of the binary package: + +```sh +DESCRIPTION Meta R html tests +INDEX NAMESPACE help profile worker +``` diff --git a/R/README.md b/R/README.md index 932d5272d0b4..47f9a86dfde1 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. 
-Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,7 +46,7 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR @@ -54,11 +54,11 @@ sc <- sparkR.init(master="local") The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae93..c5f042848c90 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -36,11 +36,27 @@ if [ ! -z "$R_HOME" ] fi echo "USING R_HOME = $R_HOME" -# Build the latest docs +# Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. "${SPARK_HOME}"/bin/load-spark-env.sh +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in $SPARK_HOME" + exit 1 +fi # Run check as-cran. 
VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` @@ -54,11 +70,16 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz - +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c3..84e6aa928cb0 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0..fe41a9e7dabb 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Version: 2.1.0 +Date: 2016-11-06 Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -18,7 +18,9 @@ Depends: Suggests: testthat, e1071, - survival + survival, + knitr, + rmarkdown Description: The SparkR package provides an R frontend for Apache Spark. 
License: Apache License (== 2.0) Collate: @@ -48,3 +50,4 @@ Collate: 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bdd..73a5e26a3ba9 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,12 +1,13 @@ --- title: "SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- ## Overview From 46b2550bcd3690a260b995fd4d024a73b92a0299 Mon Sep 17 00:00:00 2001 From: sethah Date: Sat, 12 Nov 2016 01:38:26 +0000 Subject: [PATCH 195/381] [SPARK-18060][ML] Avoid unnecessary computation for MLOR ## What changes were proposed in this pull request? Before this patch, the gradient updates for multinomial logistic regression were computed by an outer loop over the number of classes and an inner loop over the number of features. Inside the inner loop, we standardized the feature value (`value / featuresStd(index)`), which means we performed the computation `numFeatures * numClasses` times. We only need to perform that computation `numFeatures` times, however. If we re-order the inner and outer loop, we can avoid this, but then we lose sequential memory access. In this patch, we instead lay out the coefficients in column major order while we train, so that we can avoid the extra computation and retain sequential memory access. We convert back to row-major order when we create the model. ## How was this patch tested? This is an implementation detail only, so the original behavior should be maintained. All tests pass. I ran some performance tests to verify speedups. The results are below, and show significant speedups. ## Performance Tests **Setup** 3 node bare-metal cluster 120 cores total 384 gb RAM total **Results** NOTE: The `currentMasterTime` and `thisPatchTime` are times in seconds for a single iteration of L-BFGS or OWL-QN. | | numPoints | numFeatures | numClasses | regParam | elasticNetParam | currentMasterTime (sec) | thisPatchTime (sec) | pctSpeedup | |----|-------------|---------------|--------------|------------|-------------------|---------------------------|-----------------------|--------------| | 0 | 1e+07 | 100 | 500 | 0.5 | 0 | 90 | 18 | 80 | | 1 | 1e+08 | 100 | 50 | 0.5 | 0 | 90 | 19 | 78 | | 2 | 1e+08 | 100 | 50 | 0.05 | 1 | 72 | 19 | 73 | | 3 | 1e+06 | 100 | 5000 | 0.5 | 0 | 93 | 53 | 43 | | 4 | 1e+07 | 100 | 5000 | 0.5 | 0 | 900 | 390 | 56 | | 5 | 1e+08 | 100 | 500 | 0.5 | 0 | 840 | 174 | 79 | | 6 | 1e+08 | 100 | 200 | 0.5 | 0 | 360 | 72 | 80 | | 7 | 1e+08 | 1000 | 5 | 0.5 | 0 | 9 | 3 | 66 | Author: sethah Closes #15593 from sethah/MLOR_PERF_COL_MAJOR_COEF. 
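The index arithmetic behind the layout change is easy to lose track of, so here is a toy sketch (not Spark code) of the row-major to column-major mapping, assuming `numClasses = 3`, `numFeatures = 2` and no intercept term:

```scala
// Toy illustration only: map a row-major coefficient index (model layout) to the
// column-major index used during training, for numClasses = 3 and numFeatures = 2.
val numClasses = 3
val numFeatures = 2

// Row-major:    beta_11, beta_12, beta_21, beta_22, beta_31, beta_32
// Column-major: beta_11, beta_21, beta_31, beta_12, beta_22, beta_32
def toColMajor(rowMajorIndex: Int): Int = {
  val classIndex = rowMajorIndex / numFeatures
  val featureIndex = rowMajorIndex % numFeatures
  featureIndex * numClasses + classIndex
}

// Every class's coefficient for a given feature is now contiguous, so the standardized
// value (value / featuresStd(index)) is computed once per feature, not once per class.
assert((0 until numClasses * numFeatures).map(toColMajor) == Seq(0, 3, 1, 4, 2, 5))
```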
--- .../classification/LogisticRegression.scala | 125 +++++++++++------- 1 file changed, 74 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c4651054fd76..18b9b3043db8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -438,18 +438,14 @@ class LogisticRegression @Since("1.2.0") ( val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept - val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0) + val isIntercept = $(fitIntercept) && index >= numFeatures * numCoefficientSets if (isIntercept) { 0.0 } else { if (standardizationParam) { regParamL1 } else { - val featureIndex = if ($(fitIntercept)) { - index % numFeaturesPlusIntercept - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to // perform this reverse standardization by penalizing each component @@ -466,6 +462,15 @@ class LogisticRegression @Since("1.2.0") ( new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } + /* + The coefficients are laid out in column major order during training. e.g. for + `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: + + Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, + intercept_3) + + where beta_jk corresponds to the coefficient for class `j` and feature `k`. + */ val initialCoefficientsWithIntercept = Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) @@ -489,13 +494,14 @@ class LogisticRegression @Since("1.2.0") ( val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix providedCoef.foreachActive { (row, col, value) => - val flatIndex = row * numFeaturesPlusIntercept + col + // convert matrix to column major for training + val flatIndex = col * numCoefficientSets + row // We need to scale the coefficients since they will be trained in the scaled space initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) } if ($(fitIntercept)) { optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = (index + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numCoefficientSets * numFeatures + index initialCoefWithInterceptArray(coefIndex) = value } } @@ -526,7 +532,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) = + initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = rawIntercepts(i) - rawMean } } else if ($(fitIntercept)) { @@ -572,16 +578,20 @@ class LogisticRegression @Since("1.2.0") ( /* The coefficients are trained in the scaled space; we're converting them back to the original space. + + Additionally, since the coefficients were laid out in column major order during training + to avoid extra computation, we convert them back to row major before passing them to the + model. 
+ Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ val rawCoefficients = state.x.toArray.clone() val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - // flatIndex will loop though rawCoefficients, and skip the intercept terms. - val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i + val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures val featureIndex = i % numFeatures if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(flatIndex) / featuresStd(featureIndex) + rawCoefficients(colMajorIndex) / featuresStd(featureIndex) } else { 0.0 } @@ -618,7 +628,7 @@ class LogisticRegression @Since("1.2.0") ( val interceptsArray: Array[Double] = if ($(fitIntercept)) { Array.tabulate(numCoefficientSets) { i => - val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1 + val coefIndex = numFeatures * numCoefficientSets + i rawCoefficients(coefIndex) } } else { @@ -697,6 +707,7 @@ class LogisticRegressionModel private[spark] ( /** * A vector of model coefficients for "binomial" logistic regression. If this model was trained * using the "multinomial" family then an exception is thrown. + * * @return Vector */ @Since("2.0.0") @@ -720,6 +731,7 @@ class LogisticRegressionModel private[spark] ( /** * The model intercept for "binomial" logistic regression. If this model was fit with the * "multinomial" family then an exception is thrown. + * * @return Double */ @Since("1.3.0") @@ -1389,6 +1401,12 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *

* + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. + * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -1486,23 +1504,25 @@ private class LogisticAggregator( var marginOfLabel = 0.0 var maxMargin = Double.NegativeInfinity - val margins = Array.tabulate(numClasses) { i => - var margin = 0.0 - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - margin += localCoefficients(i * numFeaturesPlusIntercept + index) * - value / localFeaturesStd(index) - } + val margins = new Array[Double](numClasses) + features.foreachActive { (index, value) => + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + margins(j) += localCoefficients(index * numClasses + j) * stdValue + j += 1 } - + } + var i = 0 + while (i < numClasses) { if (fitIntercept) { - margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures) + margins(i) += localCoefficients(numClasses * numFeatures + i) } - if (i == label.toInt) marginOfLabel = margin - if (margin > maxMargin) { - maxMargin = margin + if (i == label.toInt) marginOfLabel = margins(i) + if (margins(i) > maxMargin) { + maxMargin = margins(i) } - margin + i += 1 } /** @@ -1510,33 +1530,39 @@ private class LogisticAggregator( * We address this by subtracting maxMargin from all the margins, so it's guaranteed * that all of the new margins will be smaller than zero to prevent arithmetic overflow. 
*/ + val multipliers = new Array[Double](numClasses) val sum = { var temp = 0.0 - if (maxMargin > 0) { - for (i <- 0 until numClasses) { - margins(i) -= maxMargin - temp += math.exp(margins(i)) - } - } else { - for (i <- 0 until numClasses) { - temp += math.exp(margins(i)) - } + var i = 0 + while (i < numClasses) { + if (maxMargin > 0) margins(i) -= maxMargin + val exp = math.exp(margins(i)) + temp += exp + multipliers(i) = exp + i += 1 } temp } - for (i <- 0 until numClasses) { - val multiplier = math.exp(margins(i)) / sum - { - if (label == i) 1.0 else 0.0 - } - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientArray(i * numFeaturesPlusIntercept + index) += - weight * multiplier * value / localFeaturesStd(index) + margins.indices.foreach { i => + multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0) + } + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + localGradientArray(index * numClasses + j) += + weight * multipliers(j) * stdValue + j += 1 } } - if (fitIntercept) { - localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier + } + if (fitIntercept) { + var i = 0 + while (i < numClasses) { + localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i) + i += 1 } } @@ -1637,6 +1663,7 @@ private class LogisticCostFun( val bcCoeffs = instances.context.broadcast(coeffs) val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length + val numCoefficientSets = if (multinomial) numClasses else 1 val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1656,7 +1683,7 @@ private class LogisticCostFun( var sum = 0.0 coeffs.foreachActive { case (index, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0) + val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. @@ -1665,11 +1692,7 @@ private class LogisticCostFun( totalGradientArray(index) += regParamL2 * value value * value } else { - val featureIndex = if (fitIntercept) { - index % (numFeatures + 1) - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to From 3af894511be6fcc17731e28b284dba432fe911f5 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Fri, 11 Nov 2016 18:36:23 -0800 Subject: [PATCH 196/381] [SPARK-16759][CORE] Add a configuration property to pass caller contexts of upstream applications into Spark ## What changes were proposed in this pull request? Many applications take Spark as a computing engine and run on it. This PR adds a configuration property `spark.log.callerContext` that can be used by Spark's upstream applications (e.g. Oozie) to set up their caller contexts into Spark. In the end, Spark will combine its own caller context with the caller contexts of its upstream applications, and write them into Yarn RM log and HDFS audit log. The audit log has a config to truncate the caller contexts passed in (default 128). 
The caller contexts will be sent over rpc, so it should be concise. The call context written into HDFS log and Yarn log consists of two parts: the information `A` specified by Spark itself and the value `B` of `spark.log.callerContext` property. Currently `A` typically takes 64 to 74 characters, so `B` can have up to 50 characters (mentioned in the doc `running-on-yarn.md`) ## How was this patch tested? Manual tests. I have run some Spark applications with `spark.log.callerContext` configuration in Yarn client/cluster mode, and verified that the caller contexts were written into Yarn RM log and HDFS audit log correctly. The ways to configure `spark.log.callerContext` property: - In spark-defaults.conf: ``` spark.log.callerContext infoSpecifiedByUpstreamApp ``` - In app's source code: ``` val spark = SparkSession .builder .appName("SparkKMeans") .config("spark.log.callerContext", "infoSpecifiedByUpstreamApp") .getOrCreate() ``` When running on Spark Yarn cluster mode, the driver is unable to pass 'spark.log.callerContext' to Yarn client and AM since Yarn client and AM have already started before the driver performs `.config("spark.log.callerContext", "infoSpecifiedByUpstreamApp")`. The following example shows the command line used to submit a SparkKMeans application and the corresponding records in Yarn RM log and HDFS audit log. Command: ``` ./bin/spark-submit --verbose --executor-cores 3 --num-executors 1 --master yarn --deploy-mode client --class org.apache.spark.examples.SparkKMeans examples/target/original-spark-examples_2.11-2.1.0-SNAPSHOT.jar hdfs://localhost:9000/lr_big.txt 2 5 ``` Yarn RM log: screen shot 2016-10-19 at 9 12 03 pm HDFS audit log: screen shot 2016-10-19 at 10 18 14 pm Author: Weiqing Yang Closes #15563 from weiqingy/SPARK-16759. 
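As a rough, simplified sketch of the truncation described above (the actual logic is in `CallerContext.prepareContext` in the diff that follows), using hypothetical context values:

```scala
// Simplified sketch of the truncation applied to the combined caller context; the real
// implementation reads the limit from hadoop.caller.context.max.size (default 128).
def prepareContext(context: String, maxSize: Int = 128): String =
  if (context == null || context.length <= maxSize) context
  else context.substring(0, maxSize)

// Hypothetical values: the Spark-generated part plus an upstream spark.log.callerContext.
val sparkPart = "SPARK_TASK_application_1479891330_0001_1_JId_0_SId_1_0_TId_7_0"
val upstream  = "_oozie:0004167-161017000000835-oozie-wrkf-W"
prepareContext(sparkPart + upstream)  // under 128 characters, so it is passed through intact
```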
--- .../spark/internal/config/package.scala | 4 ++ .../org/apache/spark/scheduler/Task.scala | 13 ++++- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++++++++------- docs/configuration.md | 9 ++++ .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../org/apache/spark/deploy/yarn/Client.scala | 3 +- 6 files changed, 61 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4a3e3d5c79ef..2951bdc18bc5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -207,6 +207,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") + .stringConf + .createOptional + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e..d39651a72232 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance @@ -92,8 +93,16 @@ private[spark] abstract class Task[T]( kill(interruptThread = false) } - new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), - Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + new CallerContext( + "TASK", + SparkEnv.get.conf.get(APP_CALLER_CONTEXT), + appId, + appAttemptId, + jobId, + Option(stageId), + Option(stageAttemptId), + Option(taskAttemptId), + Option(attemptNumber)).setCurrentContext() try { runTask(context) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1de66af632a8..c27cbe319284 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2569,6 +2569,7 @@ private[util] object CallerContext extends Logging { * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) * * The parameters below are optional: + * @param upstreamCallerContext caller context the upstream application passes in * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to * @param jobId id of the job this task belongs to @@ -2578,26 +2579,38 @@ private[util] object CallerContext extends Logging { * @param taskAttemptNumber task attempt id */ private[spark] class CallerContext( - from: String, - appId: Option[String] = None, - appAttemptId: Option[String] = None, - jobId: Option[Int] = None, - stageId: Option[Int] = None, - stageAttemptId: Option[Int] = None, - taskId: Option[Long] = None, - taskAttemptNumber: Option[Int] = None) extends Logging { - - val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" - val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" - val 
jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" - val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" - val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" - val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" - val taskAttemptNumberStr = - if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" - - val context = "SPARK_" + from + appIdStr + appAttemptIdStr + - jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + from: String, + upstreamCallerContext: Option[String] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + private val context = prepareContext("SPARK_" + + from + + appId.map("_" + _).getOrElse("") + + appAttemptId.map("_" + _).getOrElse("") + + jobId.map("_JId_" + _).getOrElse("") + + stageId.map("_SId_" + _).getOrElse("") + + stageAttemptId.map("_" + _).getOrElse("") + + taskId.map("_TId_" + _).getOrElse("") + + taskAttemptNumber.map("_" + _).getOrElse("") + + upstreamCallerContext.map("_" + _).getOrElse("")) + + private def prepareContext(context: String): String = { + // The default max size of Hadoop caller context is 128 + lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128) + if (context == null || context.length <= len) { + context + } else { + val finalContext = context.substring(0, len) + logWarning(s"Truncated Spark caller context from $context to $finalContext") + finalContext + } + } /** * Set up the caller context [[context]] by invoking Hadoop CallerContext API of diff --git a/docs/configuration.md b/docs/configuration.md index 41c1778ee7fc..ea99592408ba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,6 +202,15 @@ of the most common options to set are: or remotely ("cluster") on one of the nodes inside the cluster. + + spark.log.callerContext + (none) + + Application information that will be written into Yarn RM log/HDFS audit log when running on Yarn/HDFS. + Its length depends on the Hadoop configuration hadoop.caller.context.max.size. It should be concise, + and typically can have up to 50 characters. 
+ + Apart from these, the following properties are also available, and may be useful in some situations: diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f2b9dfb4d184..918cc2dd04ab 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -202,7 +202,8 @@ private[spark] class ApplicationMaster( attemptID = Option(appAttemptId.getAttemptId.toString) } - new CallerContext("APPMASTER", + new CallerContext( + "APPMASTER", sparkConf.get(APP_CALLER_CONTEXT), Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext() logInfo("ApplicationAttemptId: " + appAttemptId) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e77fa386dc93..1b75688b280e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -161,7 +161,8 @@ private[spark] class Client( reportLauncherState(SparkAppHandle.State.SUBMITTED) launcherBackend.setAppId(appId.toString) - new CallerContext("CLIENT", Option(appId.toString)).setCurrentContext() + new CallerContext("CLIENT", sparkConf.get(APP_CALLER_CONTEXT), + Option(appId.toString)).setCurrentContext() // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) From bc41d997ea287080f549219722b6d9049adef4e2 Mon Sep 17 00:00:00 2001 From: Guoqiang Li Date: Sat, 12 Nov 2016 09:49:14 +0000 Subject: [PATCH 197/381] [SPARK-18375][SPARK-18383][BUILD][CORE] Upgrade netty to 4.0.42.Final ## What changes were proposed in this pull request? One of the important changes for 4.0.42.Final is "Support any FileRegion implementation when using epoll transport netty/netty#5825". In 4.0.42.Final, `MessageWithHeader` can work properly when `spark.[shuffle|rpc].io.mode` is set to epoll ## How was this patch tested? Existing tests Author: Guoqiang Li Closes #15830 from witgo/SPARK-18375_netty-4.0.42. 
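For context, a hypothetical way to select the epoll transport mentioned above from application code (a sketch, not part of this patch):

```scala
import org.apache.spark.SparkConf

// Hypothetical configuration switching the network transports to the native epoll
// implementation (Linux only), the code path exercised by the netty 4.0.42.Final fix.
val conf = new SparkConf()
  .set("spark.rpc.io.mode", "epoll")
  .set("spark.shuffle.io.mode", "epoll")
```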
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++++ dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 7 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c27cbe319284..d341982ae9e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -39,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses @@ -2222,6 +2223,9 @@ private[spark] object Utils extends Logging { isBindCollision(e.getCause) case e: MultiException => e.getThrowables.asScala.exists(isBindCollision) + case e: NativeIoException => + (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || + isBindCollision(e.getCause) case e: Exception => isBindCollision(e.getCause) case _ => false } diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 6e749ac16cac..bbdea069f949 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -123,7 +123,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 515995a0a46b..a2dec41d6451 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index d2139fd95240..c1f02b93d751 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -130,7 +130,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index b5cecf72ec35..4f04636be712 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a5e03a78e7ea..da3af9ffa155 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/pom.xml b/pom.xml index 8aa0a6c3caab..650b4cd965b6 100644 --- a/pom.xml +++ b/pom.xml @@ -552,7 +552,7 @@ io.netty netty-all - 4.0.41.Final + 4.0.42.Final io.netty From 22cb3a060a440205281b71686637679645454ca6 Mon Sep 17 00:00:00 2001 
From: Yanbo Liang Date: Sat, 12 Nov 2016 06:13:22 -0800 Subject: [PATCH 198/381] [SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes ## What changes were proposed in this pull request? * Refactor out ```trainWithLabelCheck``` and make ```mllib.NaiveBayes``` call into it. * Avoid capturing the outer object for ```modelType```. * Move ```requireNonnegativeValues``` and ```requireZeroOneBernoulliValues``` to companion object. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #15826 from yanboliang/spark-14077-2. --- .../spark/ml/classification/NaiveBayes.scala | 72 +++++++++---------- .../mllib/classification/NaiveBayes.scala | 6 +- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b03a07a6bc1e..f1a7676c74b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { - import NaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes._ @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + trainWithLabelCheck(dataset, positiveLabel = true) + } + /** * ml assumes input labels in range [0, numClasses). But this implementation * is also called by mllib NaiveBayes which allows other kinds of input labels - * such as {-1, +1}. Here we use this parameter to switch between different processing logic. - * It should be removed when we remove mllib NaiveBayes. + * such as {-1, +1}. `positiveLabel` is used to determine whether the label + * should be checked and it should be removed when we remove mllib NaiveBayes. 
*/ - private[spark] var isML: Boolean = true - - private[spark] def setIsML(isML: Boolean): this.type = { - this.isML = isML - this - } - - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - if (isML) { + private[spark] def trainWithLabelCheck( + dataset: Dataset[_], + positiveLabel: Boolean): NaiveBayesModel = { + if (positiveLabel) { val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") ( } } - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(_ >= 0.0), - s"Naive Bayes requires nonnegative feature values but found $v.") - } - - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(v => v == 0.0 || v == 1.0), - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - + val modelTypeValue = $(modelType) val requireValues: Vector => Unit = { - $(modelType) match { + modelTypeValue match { case Multinomial => requireNonnegativeValues case Bernoulli => @@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { /** String name for multinomial model type. */ - private[spark] val Multinomial: String = "multinomial" + private[classification] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. */ - private[spark] val Bernoulli: String = "bernoulli" + private[classification] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 33561be4b5bc..767d056861a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,12 +364,12 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) - .setIsML(false) val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } .toDF("label", "features") - val newModel = nb.fit(dataset) + // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false. 
+ val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false) val pi = newModel.pi.toArray val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) @@ -378,7 +378,7 @@ class NaiveBayes private ( theta(i)(j) = v } - require(newModel.oldLabels != null, + assert(newModel.oldLabels != null, "The underlying ML NaiveBayes training does not produce labels.") new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } From 1386fd28daf798bf152606f4da30a36223d75d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 12 Nov 2016 14:50:37 -0800 Subject: [PATCH 199/381] [SPARK-18418] Fix flags for make_binary_release for hadoop profile ## What changes were proposed in this pull request? Fix the flags used to specify the hadoop version ## How was this patch tested? Manually tested as part of https://github.com/apache/spark/pull/15659 by having the build succeed. cc joshrosen Author: Holden Karau Closes #15860 from holdenk/minor-fix-release-build-script. --- dev/create-release/release-build.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 96f9b5714ebb..81f0d63054e2 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -187,10 +187,10 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos" - make_binary_release "hadoop2.3" "-Phadoop2.3 $FLAGS" "3033" & - make_binary_release "hadoop2.4" "-Phadoop2.4 $FLAGS" "3034" & - make_binary_release "hadoop2.6" "-Phadoop2.6 $FLAGS" "3035" & - make_binary_release "hadoop2.7" "-Phadoop2.7 $FLAGS" "3036" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 $FLAGS" "3033" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 $FLAGS" "3034" & + make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" & + make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn -Pmesos" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" & wait From b91a51bb231af321860415075a7f404bc46e0a74 Mon Sep 17 00:00:00 2001 From: Denny Lee Date: Sun, 13 Nov 2016 18:10:06 -0800 Subject: [PATCH 200/381] [SPARK-18426][STRUCTURED STREAMING] Python Documentation Fix for Structured Streaming Programming Guide ## What changes were proposed in this pull request? Update the python section of the Structured Streaming Guide from .builder() to .builder ## How was this patch tested? Validated documentation and successfully running the test example. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. 'Builder' object is not callable object hence changed .builder() to .builder Author: Denny Lee Closes #15872 from dennyglee/master. 
--- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d838ed35a14f..d2545584ae3b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -58,7 +58,7 @@ from pyspark.sql.functions import explode from pyspark.sql.functions import split spark = SparkSession \ - .builder() \ + .builder \ .appName("StructuredNetworkWordCount") \ .getOrCreate() {% endhighlight %} From 07be232ea12dfc8dc3701ca948814be7dbebf4ee Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 13 Nov 2016 20:25:12 -0800 Subject: [PATCH 201/381] [SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data ## What changes were proposed in this pull request? * Fix the following exceptions which throws when ```spark.randomForest```(classification), ```spark.gbt```(classification), ```spark.naiveBayes``` and ```spark.glm```(binomial family) were fitted on libsvm data. ``` java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true. ``` See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more detail about how to reproduce this bug. * Refactor out ```getFeaturesAndLabels``` to RWrapperUtils, since lots of ML algorithm wrappers use this function. * Drop some unwanted columns when making prediction. ## How was this patch tested? Add unit test. Author: Yanbo Liang Closes #15851 from yanboliang/spark-18412. --- R/pkg/inst/tests/testthat/test_mllib.R | 18 ++++++++-- .../spark/ml/r/GBTClassificationWrapper.scala | 18 ++++------ .../GeneralizedLinearRegressionWrapper.scala | 5 ++- .../apache/spark/ml/r/NaiveBayesWrapper.scala | 14 +++----- .../org/apache/spark/ml/r/RWrapperUtils.scala | 36 ++++++++++++++++--- .../r/RandomForestClassificationWrapper.scala | 18 ++++------ 6 files changed, 68 insertions(+), 41 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index b76f75dbdc68..07df4b6d6f84 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -881,7 +881,8 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", { predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) }) 
test_that("spark.gbt", { @@ -1039,6 +1045,12 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index 894602503220..aacb41ee2659 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class GBTClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(gbtcModel.getFeaturesCol) + .drop(gbtcModel.getLabelCol) } override def write: MLWriter = new @@ -81,19 +82,11 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new GBTClassifier() @@ -109,6 +102,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] .setMaxMemoryInMB(maxMemoryInMB) .setCacheNodeIds(cacheNodeIds) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 995b1ef03bce..add4d49110d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.regression._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.r.RWrapperUtils._ import 
org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( .drop(PREDICTED_LABEL_PROB_COL) .drop(PREDICTED_LABEL_INDEX_COL) .drop(glm.getFeaturesCol) + .drop(glm.getLabelCol) } else { pipeline.transform(dataset) .drop(glm.getFeaturesCol) @@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper regParam: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) if (family == "binomial") rFormula.setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema @@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) val pipeline = if (family == "binomial") { // Convert prediction from probability to label index. val probToPred = new ProbabilityToPrediction() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 4fdab2dd9465..0afea4be3d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(naiveBayesModel.getFeaturesCol) + .drop(naiveBayesModel.getLabelCol) } override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this) @@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = rFormulaModel.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val naiveBayes = new NaiveBayes() .setSmoothing(smoothing) .setModelType("bernoulli") .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 379007c4d948..665e50af67d4 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.r import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{RFormula, RFormulaModel} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.Dataset -object RWrapperUtils extends Logging { +private[r] object RWrapperUtils extends Logging { /** * DataFrame column check. @@ -32,14 +33,41 @@ object RWrapperUtils extends Logging { * * @param rFormula RFormula instance * @param data Input dataset - * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}" - logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + + logInfo(s"data containing ${rFormula.getFeaturesCol} column, " + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } + + if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) { + val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}" + logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " + + s"using new name $newLabelName instead") + rFormula.setLabelCol(newLabelName) + } + } + + /** + * Get the feature names and original labels from the schema + * of DataFrame transformed by RFormulaModel. + * + * @param rFormulaModel The RFormulaModel instance. + * @param data Input dataset. + * @return The feature names and original labels. 
+ */ + def getFeaturesAndLabels( + rFormulaModel: RFormulaModel, + data: Dataset[_]): (Array[String], Array[String]) = { + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + (features, labels) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 31f846dc6cfe..0b860e5af96e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -51,6 +51,7 @@ private[r] class RandomForestClassifierWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(rfcModel.getFeaturesCol) + .drop(rfcModel.getLabelCol) } override def write: MLWriter = new @@ -82,19 +83,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) - - // get label names from output schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new RandomForestClassifier() @@ -111,6 +104,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) From f95b124c68ccc2e318f6ac30685aa47770eea8f3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Nov 2016 16:52:07 +0900 Subject: [PATCH 202/381] [SPARK-18382][WEBUI] "run at null:-1" in UI when no file/line info in call site info ## What changes were proposed in this pull request? Avoid reporting null/-1 file / line number in call sites if encountering StackTraceElement without this info ## How was this patch tested? Existing tests Author: Sean Owen Closes #15862 from srowen/SPARK-18382. 
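As a rough illustration of the failure mode (a standalone sketch with made-up names, not the code touched by this patch): a `StackTraceElement` reports `null` and `-1` when file and line information is unavailable, so a call-site string should only pick those fields up when they are actually present, which is what the guard in the diff below does.

```scala
// Hypothetical sketch, not Spark's Utils.getCallSite itself: format a call-site
// string defensively, mirroring the null/-1 guard added in the diff below.
def formatCallSite(ste: StackTraceElement): String = {
  var file = "<unknown>"
  var line = -1
  if (ste.getFileName != null) {
    file = ste.getFileName
    if (ste.getLineNumber >= 0) {
      line = ste.getLineNumber
    }
  }
  if (line >= 0) s"run at $file:$line" else s"run at $file"
}

// A frame without file/line info (e.g. generated or native code) would otherwise
// render as "run at null:-1"; lineNumber -2 is how the JVM marks native frames.
val frame = new StackTraceElement("SomeClass", "someMethod", null, -2)
println(formatCallSite(frame)) // prints: run at <unknown>
```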
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d341982ae9e8..23b95b9f649f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1419,8 +1419,12 @@ private[spark] object Utils extends Logging { } callStack(0) = ste.toString // Put last Spark method on top of the stack trace. } else { - firstUserLine = ste.getLineNumber - firstUserFile = ste.getFileName + if (ste.getFileName != null) { + firstUserFile = ste.getFileName + if (ste.getLineNumber >= 0) { + firstUserLine = ste.getLineNumber + } + } callStack += ste.toString insideSpark = false } From ae6cddb78742be94aa0851ce719f293e0a64ce4f Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Mon, 14 Nov 2016 12:08:06 +0100 Subject: [PATCH 203/381] [SPARK-18166][MLLIB] Fix Poisson GLM bug due to wrong requirement of response values ## What changes were proposed in this pull request? The current implementation of Poisson GLM seems to allow only positive values. This is incorrect since the support of Poisson includes the origin. The bug is easily fixed by changing the test of the Poisson variable from 'require(y **>** 0.0' to 'require(y **>=** 0.0'. mengxr srowen Author: actuaryzhang Author: actuaryzhang Closes #15683 from actuaryzhang/master. --- .../GeneralizedLinearRegression.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1938e8ecc513..1d2961e0277f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine val defaultLink: Link = Log override def initialize(y: Double, weight: Double): Double = { - require(y > 0.0, "The response variable of Poisson family " + - s"should be positive, but got $y") + require(y >= 0.0, "The response variable of Poisson family " + + s"should be non-negative, but got $y") y } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 111bc974642d..6a4ac1735b2c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite @transient var datasetGaussianInverse: DataFrame = _ @transient var datasetBinomial: DataFrame = _ @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonLogWithZero: DataFrame = _ @transient var datasetPoissonIdentity: DataFrame = _ @transient var datasetPoissonSqrt: DataFrame = _ @transient var datasetGammaInverse: DataFrame = _ @@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() + datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( + 
intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01, + family = "poisson", link = "log") + .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF() + datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog") + datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero") datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) => label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile( @@ -456,6 +467,40 @@ class GeneralizedLinearRegressionSuite } } + test("generalized linear regression: poisson family against glm (with zero values)") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + print(as.vector(coef(model))) + } + [1] 0.4272661 -0.1565423 + [1] -3.6911354 0.6214301 0.1295814 + */ + val expected = Seq( + Vectors.dense(0.0, 0.4272661, -0.1565423), + Vectors.dense(-3.6911354, 0.6214301, 0.1295814)) + + import GeneralizedLinearRegression._ + + var idx = 0 + val link = "log" + val dataset = datasetPoissonLogWithZero + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept (with zero values).") + idx += 1 + } + } + test("generalized linear regression: gamma family against glm") { /* R code: From 637a0bb88f74712001f32a53ff66fd0b8cb67e4a Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Mon, 14 Nov 2016 12:22:36 +0100 Subject: [PATCH 204/381] [SPARK-18396][HISTORYSERVER] Duration" column makes search result confused, maybe we should make it unsearchable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When we search data in History Server, it will check if any columns contains the search string. Duration is represented as long value in table, so if we search simple string like "003", "111", the duration containing "003", ‘111“ will be showed, which make not much sense to users. We cannot simply transfer the long value to meaning format like "1 h", "3.2 min" because they are also used for sorting. Better way to handle it is ban "Duration" columns from searching. ## How was this patch tested manually tests. 
Before("local-1478225166651" pass the filter because its duration in long value, which is "257244245" contains search string "244"): ![before](https://cloud.githubusercontent.com/assets/5276001/20203166/f851ffc6-a7ff-11e6-8fe6-91a90ca92b23.jpg) After: ![after](https://cloud.githubusercontent.com/assets/5276001/20178646/2129fbb0-a78d-11e6-9edb-39f885ce3ed0.jpg) Author: WangTaoTheTonic Closes #15838 from WangTaoTheTonic/duration. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6c0ec8d5fce5..8fd91865b042 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -139,6 +139,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; From 9d07ceee7860921eafb55b47852f1b51089c98da Mon Sep 17 00:00:00 2001 From: Noritaka Sekiyama Date: Mon, 14 Nov 2016 21:07:59 +0900 Subject: [PATCH 205/381] [SPARK-18432][DOC] Changed HDFS default block size from 64MB to 128MB Changed HDFS default block size from 64MB to 128MB. https://issues.apache.org/jira/browse/SPARK-18432 Author: Noritaka Sekiyama Closes #15879 from moomindani/SPARK-18432. --- docs/programming-guide.md | 6 +++--- docs/tuning.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index b9a2110b602a..58bf17b4a84e 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -343,7 +343,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Scala API also supports several other data formats: @@ -375,7 +375,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. 
Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Java API also supports several other data formats: @@ -407,7 +407,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Python API also supports several other data formats: diff --git a/docs/tuning.md b/docs/tuning.md index 9c43b315bbb9..0de303a3bd9b 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -224,8 +224,8 @@ temporary objects created during task execution. Some steps which may be useful * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 64 MB, - we can estimate size of Eden to be `4*3*64MB`. + size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MB, + we can estimate size of Eden to be `4*3*128MB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. From bdfe60ac921172be0fb77de2f075cc7904a3b238 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 14 Nov 2016 10:03:01 -0800 Subject: [PATCH 206/381] [SPARK-18416][STRUCTURED STREAMING] Fixed temp file leak in state store ## What changes were proposed in this pull request? StateStore.get() causes temporary files to be created immediately, even if the store is not used to make updates for new version. The temp file is not closed as store.commit() is not called in those cases, thus keeping the output stream to temp file open forever. This PR fixes it by opening the temp file only when there are updates being made. ## How was this patch tested? New unit test Author: Tathagata Das Closes #15859 from tdas/SPARK-18416. 
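The essence of the fix is switching the temp-file stream from eager creation to lazy creation. A generic sketch of that pattern follows, using hypothetical names rather than the real `HDFSBackedStateStoreProvider` internals shown in the diff below:

```scala
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}

// Hypothetical delta writer: the temp file is only created the first time an
// update is written, so stores that are opened but never updated leak nothing.
class LazyDeltaWriter(tempFile: File) {
  // `lazy val` defers file creation until the stream is first used.
  private lazy val out: OutputStream =
    new BufferedOutputStream(new FileOutputStream(tempFile))
  private var hasUpdates = false

  def write(bytes: Array[Byte]): Unit = {
    out.write(bytes)
    hasUpdates = true
  }

  // Only close (and thus flush) the stream if it was ever opened.
  def commit(): Unit = if (hasUpdates) out.close()

  def abort(): Unit = {
    if (hasUpdates) out.close()
    if (tempFile.exists()) tempFile.delete()
  }
}
```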
--- .../state/HDFSBackedStateStoreProvider.scala | 10 +-- .../streaming/state/StateStoreSuite.scala | 63 +++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 808713161c31..f07feaad5dc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -87,8 +87,7 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - + private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() @volatile private var state: STATE = UPDATING @@ -101,7 +100,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") + verify(state == UPDATING, "Cannot put after already committed or aborted") val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) @@ -125,6 +124,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") + val keyIter = mapToUpdate.keySet().iterator() while (keyIter.hasNext) { val key = keyIter.next @@ -154,7 +154,7 @@ private[state] class HDFSBackedStateStoreProvider( finalizeDeltaFile(tempDeltaFileStream) finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) state = COMMITTED - logInfo(s"Committed version $newVersion for $this") + logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion } catch { case NonFatal(e) => @@ -174,7 +174,7 @@ private[state] class HDFSBackedStateStoreProvider( if (tempDeltaFile != null) { fs.delete(tempDeltaFile, true) } - logInfo("Aborted") + logInfo(s"Aborted version $newVersion for $this") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 504a26516107..533cd0cd2a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -468,6 +468,69 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(e.getCause.getMessage.contains("Failed to rename")) } + test("SPARK-18416: do not create temp delta file until the store is updated") { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val deltaFileDir = new File(s"$dir/0/0/") + + def numTempFiles: Int = { + if (deltaFileDir.exists) { + deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + } else 0 + } + + def 
numDeltaFiles: Int = { + if (deltaFileDir.exists) { + deltaFileDir.listFiles.map(_.getName).count(n => n.contains(".delta") && !n.startsWith(".")) + } else 0 + } + + def shouldNotCreateTempFile[T](body: => T): T = { + val before = numTempFiles + val result = body + assert(numTempFiles === before) + result + } + + // Getting the store should not create temp file + val store0 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + } + + // Put should create a temp file + put(store0, "a", 1) + assert(numTempFiles === 1) + assert(numDeltaFiles === 0) + + // Commit should remove temp file and create a delta file + store0.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 1) + + // Remove should create a temp file + val store1 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + remove(store1, _ == "a") + assert(numTempFiles === 1) + assert(numDeltaFiles === 1) + + // Commit should remove temp file and create a delta file + store1.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 2) + + // Commit without any updates should create a delta file + val store2 = shouldNotCreateTempFile { + StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + } + store2.commit() + assert(numTempFiles === 0) + assert(numDeltaFiles === 3) + } + def getDataFromFiles( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { From 89d1fa58dbe88560b1f2b0362fcc3035ccc888be Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Mon, 14 Nov 2016 11:10:37 -0800 Subject: [PATCH 207/381] [SPARK-17510][STREAMING][KAFKA] config max rate on a per-partition basis ## What changes were proposed in this pull request? Allow configuration of max rate on a per-topicpartition basis. ## How was this patch tested? Unit tests. The reporter (Jeff Nadler) said he could test on his workload, so let's wait on that report. Author: cody koeninger Closes #15132 from koeninger/SPARK-17510. 
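Based on the `PerPartitionConfig` and `createDirectStream` overload added here, a caller can throttle one hot topic more aggressively than the rest. The sketch below uses made-up topic names, rates, and connection settings:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010._

val ssc = new StreamingContext(new SparkConf().setAppName("per-partition-rate"), Seconds(5))

// Hypothetical consumer settings -- adjust for a real cluster.
val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "per-partition-rate-demo")

// Cap the (hypothetical) "clicks" topic at 1000 records/sec per partition,
// everything else at 5000.
val rateConfig = new PerPartitionConfig {
  override def maxRatePerPartition(tp: TopicPartition): Long =
    if (tp.topic == "clicks") 1000L else 5000L
}

val stream = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](Seq("clicks", "impressions"), kafkaParams),
  rateConfig)
```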
--- .../kafka010/DirectKafkaInputDStream.scala | 11 ++-- .../spark/streaming/kafka010/KafkaUtils.scala | 53 ++++++++++++++++++- .../kafka010/PerPartitionConfig.scala | 47 ++++++++++++++++ .../kafka010/DirectKafkaStreamSuite.scala | 34 ++++++++---- .../kafka/DirectKafkaInputDStream.scala | 4 +- 5 files changed, 131 insertions(+), 18 deletions(-) create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 7e57bb18cbd5..794f53c5abfd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -57,7 +57,8 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator private[spark] class DirectKafkaInputDStream[K, V]( _ssc: StreamingContext, locationStrategy: LocationStrategy, - consumerStrategy: ConsumerStrategy[K, V] + consumerStrategy: ConsumerStrategy[K, V], + ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { val executorKafkaParams = { @@ -128,12 +129,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( - "spark.streaming.kafka.maxRatePerPartition", 0) - protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -144,11 +142,12 @@ private[spark] class DirectKafkaInputDStream[K, V]( val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => + val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) val backpressureRate = Math.round(lag / totalLag.toFloat * rate) tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } } if (effectiveRateLimitPerPartition.values.sum > 0) { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index b2190bfa05a3..c11917f59d5b 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -123,7 +123,31 @@ object KafkaUtils extends Logging { locationStrategy: LocationStrategy, consumerStrategy: ConsumerStrategy[K, V] ): InputDStream[ConsumerRecord[K, V]] = { - new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy) + val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf) + createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc) + } + + /** + * :: Experimental :: + * Scala constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. 
+ * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details. + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): InputDStream[ConsumerRecord[K, V]] = { + new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig) } /** @@ -150,6 +174,33 @@ object KafkaUtils extends Logging { jssc.ssc, locationStrategy, consumerStrategy)) } + /** + * :: Experimental :: + * Java constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + jssc: JavaStreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): JavaInputDStream[ConsumerRecord[K, V]] = { + new JavaInputDStream( + createDirectStream[K, V]( + jssc.ssc, locationStrategy, consumerStrategy, perPartitionConfig)) + } + /** * Tweak kafka params to prevent issues on executors */ diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala new file mode 100644 index 000000000000..4792f2a95511 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Interface for user-supplied configurations that can't otherwise be set via Spark properties, + * because they need tweaking on a per-partition basis, + */ +@Experimental +abstract class PerPartitionConfig extends Serializable { + /** + * Maximum rate (number of records per second) at which data will be read + * from each Kafka partition. + */ + def maxRatePerPartition(topicPartition: TopicPartition): Long +} + +/** + * Default per-partition configuration + */ +private class DefaultPerPartitionConfig(conf: SparkConf) + extends PerPartitionConfig { + val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + + def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index c81836da3cbb..fde3714d3d02 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -252,7 +252,8 @@ class DirectKafkaStreamSuite val s = new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -307,7 +308,8 @@ class DirectKafkaStreamSuite ConsumerStrategies.Assign[String, String]( List(topicPartition), kafkaParams.asScala, - Map(topicPartition -> 11L))) + Map(topicPartition -> 11L)), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -520,7 +522,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with backpressure disabled") { val topic = "maxMessagesPerPartition" - val kafkaStream = getDirectKafkaStream(topic, None) + val kafkaStream = getDirectKafkaStream(topic, None, None) val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L) assert(kafkaStream.maxMessagesPerPartition(input).get == @@ -530,7 +532,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with no lag") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val kafkaStream = getDirectKafkaStream(topic, rateController, None) val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L) assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) @@ -539,11 +541,19 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition respects max rate") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val ppc = Some(new PerPartitionConfig { + def maxRatePerPartition(tp: TopicPartition) = + if (tp.topic == topic && tp.partition == 0) { + 50 + } else { + 100 + } + }) + val kafkaStream = 
getDirectKafkaStream(topic, rateController, ppc) val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L) assert(kafkaStream.maxMessagesPerPartition(input).get == - Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L)) + Map(new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L)) } test("using rate controller") { @@ -572,7 +582,9 @@ class DirectKafkaStreamSuite new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) { + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { override protected[streaming] val rateController = Some(new DirectKafkaRateController(id, estimator)) }.map(r => (r.key, r.value)) @@ -618,7 +630,10 @@ class DirectKafkaStreamSuite }.toSeq.sortBy { _._1 } } - private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + private def getDirectKafkaStream( + topic: String, + mockRateController: Option[RateController], + ppc: Option[PerPartitionConfig]) = { val batchIntervalMilliseconds = 100 val sparkConf = new SparkConf() @@ -645,7 +660,8 @@ class DirectKafkaStreamSuite tps.foreach(tp => consumer.seek(tp, 0)) consumer } - } + }, + ppc.getOrElse(new DefaultPerPartitionConfig(sparkConf)) ) { override protected[streaming] val rateController = mockRateController } diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index c3c799375bbe..d52c230eb784 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -88,12 +88,12 @@ class DirectKafkaInputDStream[ protected val kc = new KafkaCluster(kafkaParams) - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { From 75934457d75996be71ffd0d4b448497d656c0d40 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 14 Nov 2016 19:42:00 +0000 Subject: [PATCH 208/381] [SPARK-11496][GRAPHX][FOLLOWUP] Add param checking for runParallelPersonalizedPageRank ## What changes were proposed in this pull request? add the param checking to keep in line with other algos ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15876 from zhengruifeng/param_check_runParallelPersonalizedPageRank. 
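To make the effect of the new checks concrete, here is an illustrative sketch (toy graph, assumes a `SparkContext` named `sc`); the actual `require` calls are in the diff below:

```scala
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.graphx.lib.PageRank

// Toy graph for illustration only.
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 1.0)))
val graph = Graph.fromEdges(edges, defaultValue = 1.0)

// Valid arguments behave as before.
val ranks = PageRank.runParallelPersonalizedPageRank(
  graph, numIter = 10, resetProb = 0.15, sources = Array(1L, 2L))

// With this patch, invalid arguments now fail fast with IllegalArgumentException
// instead of silently producing meaningless results, e.g. numIter = 0,
// resetProb outside [0, 1], or an empty `sources` array.
```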
--- .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index f4b00757a8b5..c0c3c73463aa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -185,6 +185,13 @@ object PageRank extends Logging { def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, sources: Array[VertexId]): Graph[Vector, Double] = { + require(numIter > 0, s"Number of iterations must be greater than 0," + + s" but got ${numIter}") + require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + + s" to [0, 1], but got ${resetProb}") + require(sources.nonEmpty, s"The list of sources must be non-empty," + + s" but got ${sources.mkString("[", ",", "]")}") + // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector val zero = Vectors.sparse(sources.size, List()).asBreeze From bd85603ba5f9e61e1aa8326d3e4d5703b5977a4c Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 14 Nov 2016 20:59:15 +0100 Subject: [PATCH 209/381] [SPARK-17348][SQL] Incorrect results from subquery transformation ## What changes were proposed in this pull request? Return an Analysis exception when there is a correlated non-equality predicate in a subquery and the correlated column from the outer reference is not from the immediate parent operator of the subquery. This PR prevents incorrect results from subquery transformation in such case. Test cases, both positive and negative tests, are added. ## How was this patch tested? sql/test, catalyst/test, hive/test, and scenarios that will produce incorrect results without this PR and product correct results when subquery transformation does happen. Author: Nattavut Sutyanyong Closes #15763 from nsyca/spark-17348. --- .../sql/catalyst/analysis/Analyzer.scala | 44 +++++++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 7 -- .../org/apache/spark/sql/SubquerySuite.scala | 95 ++++++++++++++++++- 3 files changed, 137 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dd68d60d3e83..c14f35351708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1031,6 +1031,37 @@ class Analyzer( } } + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. 
+ // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + /** Determine which correlated predicate references are missing from this plan. */ def missingReferences(p: LogicalPlan): AttributeSet = { val localPredicateReferences = p.collect(predicateMap) @@ -1041,12 +1072,20 @@ class Analyzer( localPredicateReferences -- p.outputSet } + var foundNonEqualCorrelatedPred : Boolean = false + // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + // Rewrite the filter without the correlated predicates if any. correlated match { case Nil => f @@ -1068,12 +1107,17 @@ class Analyzer( } case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + val referencesToAdd = missingReferences(a) if (referencesToAdd.nonEmpty) { Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) } else { a } + case w : Window => + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) + w case j @ Join(left, _, RightOuter, _) => failOnOuterReference(j) failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3455a567b778..7b75c1f70974 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -119,13 +119,6 @@ trait CheckAnalysis extends PredicateHelper { } case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure we are using equi-joins. - conditions.foreach { - case _: EqualTo | _: EqualNullSafe => // ok - case e => failAnalysis( - s"The correlated scalar subquery can only contain equality predicates: $e") - } - // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. 
// The analyzer has already checked that subquery contained only one output column, and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 89348668340b..c84a6f161893 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -498,10 +498,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("non-equal correlated scalar subquery") { val msg1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1") + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") } assert(msg1.getMessage.contains( - "The correlated scalar subquery can only contain equality predicates")) + "Correlated column is not allowed in a non-equality predicate:")) } test("disjunctive correlated scalar subquery") { @@ -639,6 +639,97 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | from t1 left join t2 on t1.c1=t2.c2) t3 | where c3 not in (select c2 from t2)""".stripMargin), Row(2) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (good case)") { + withTempView("t1", "t2") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + + // Simple case + checkAnswer( + sql( + """ + | select c1 + | from t1 + | where c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin), + Row(1) :: Nil) + + // More complex case with OR predicate + checkAnswer( + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select t2.c1 + | from t2 + | where t1.c2 >= t2.c2 + | or t3.c2 < t2.c2) + | or t1.c2 >= 0)""".stripMargin), + Row(1) :: Nil) + } + } + + test("SPARK-17348: Correlated subqueries with non-equality predicate (error case)") { + withTempView("t1", "t2", "t3", "t4") { + Seq((1, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((2, 1)).toDF("c1", "c2").createOrReplaceTempView("t3") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4") + + // Simplest case + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2)""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2 + | having count(*) > 0 ) + | or t1.c2 >= 0""".stripMargin).collect() + } + + // Add a HAVING on top and augmented within an OR predicate + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1, t1 as t3 + | where t1.c1 = t3.c1 + | and (t1.c1 in (select max(t2.c1) + | from t2 + | where t1.c2 = t2.c2 + | or t3.c2 = t2.c2) + | )""".stripMargin).collect() + } + + // In Window expression: changing the data set to + // demonstrate if this query ran, it would return incorrect result. 
+ intercept[AnalysisException] { + sql( + """ + | select c1 + | from t3 + | where c1 in (select max(t4.c1) over () + | from t4 + | where t3.c2 >= t4.c2)""".stripMargin).collect() + } } } } From c07187823a98f0d1a0f58c06e28a27e1abed157a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 14 Nov 2016 16:46:26 -0800 Subject: [PATCH 210/381] [SPARK-18124] Observed delay based Event Time Watermarks This PR adds a new method `withWatermark` to the `Dataset` API, which can be used specify an _event time watermark_. An event time watermark allows the streaming engine to reason about the point in time after which we no longer expect to see late data. This PR also has augmented `StreamExecution` to use this watermark for several purposes: - To know when a given time window aggregation is finalized and thus results can be emitted when using output modes that do not allow updates (e.g. `Append` mode). - To minimize the amount of state that we need to keep for on-going aggregations, by evicting state for groups that are no longer expected to change. Although, we do still maintain all state if the query requires (i.e. if the event time is not present in the `groupBy` or when running in `Complete` mode). An example that emits windowed counts of records, waiting up to 5 minutes for late data to arrive. ```scala df.withWatermark("eventTime", "5 minutes") .groupBy(window($"eventTime", "1 minute") as 'window) .count() .writeStream .format("console") .mode("append") // In append mode, we only output finalized aggregations. .start() ``` ### Calculating the watermark. The current event time is computed by looking at the `MAX(eventTime)` seen this epoch across all of the partitions in the query minus some user defined _delayThreshold_. An additional constraint is that the watermark must increase monotonically. Note that since we must coordinate this value across partitions occasionally, the actual watermark used is only guaranteed to be at least `delay` behind the actual event time. In some cases we may still process records that arrive more than delay late. This mechanism was chosen for the initial implementation over processing time for two reasons: - it is robust to downtime that could affect processing delay - it does not require syncing of time or timezones between the producer and the processing engine. ### Other notable implementation details - A new trigger metric `eventTimeWatermark` outputs the current value of the watermark. - We mark the event time column in the `Attribute` metadata using the key `spark.watermarkDelay`. This allows downstream operations to know which column holds the event time. Operations like `window` propagate this metadata. - `explain()` marks the watermark with a suffix of `-T${delayMs}` to ease debugging of how this information is propagated. - Currently, we don't filter out late records, but instead rely on the state store to avoid emitting records that are both added and filtered in the same epoch. ### Remaining in this PR - [ ] The test for recovery is currently failing as we don't record the watermark used in the offset log. We will need to do so to ensure determinism, but this is deferred until #15626 is merged. ### Other follow-ups There are some natural additional features that we should consider for future work: - Ability to write records that arrive too late to some external store in case any out-of-band remediation is required. - `Update` mode so you can get partial results before a group is evicted. - Other mechanisms for calculating the watermark. 
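As a rough sketch of the update rule described above (a made-up class, not the engine's internal implementation): the watermark is the maximum event time observed so far minus the user-supplied delay, and it only ever moves forward.

```scala
// Hypothetical tracker mirroring the semantics described above.
class EventTimeWatermarkTracker(delayMs: Long) {
  private var watermarkMs = 0L

  /** Called after each batch with the max event time (ms) seen across partitions. */
  def updateAfterBatch(maxEventTimeMs: Long): Long = {
    val candidate = maxEventTimeMs - delayMs
    if (candidate > watermarkMs) {
      watermarkMs = candidate // monotonically increasing only
    }
    watermarkMs
  }
}

// Windows whose end falls at or below the returned watermark can be finalized:
// their results are emitted in Append mode and their state evicted.
```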
In particular a watermark based on quantiles would be more robust to outliers. Author: Michael Armbrust Closes #15702 from marmbrus/watermarks. --- .../spark/unsafe/types/CalendarInterval.java | 4 + .../apache/spark/sql/AnalysisException.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 + .../UnsupportedOperationChecker.scala | 18 +- .../sql/catalyst/analysis/unresolved.scala | 3 +- .../expressions/namedExpressions.scala | 17 +- .../plans/logical/EventTimeWatermark.scala | 51 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 40 +++- .../spark/sql/execution/SparkStrategies.scala | 12 +- .../sql/execution/aggregate/AggUtils.scala | 9 +- .../sql/execution/command/commands.scala | 2 +- .../streaming/EventTimeWatermarkExec.scala | 93 +++++++++ .../sql/execution/streaming/ForeachSink.scala | 3 +- .../streaming/IncrementalExecution.scala | 12 +- .../streaming/StatefulAggregate.scala | 170 +++++++++------- .../execution/streaming/StreamExecution.scala | 25 ++- .../execution/streaming/StreamMetrics.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 23 ++- .../streaming/state/StateStore.scala | 7 +- .../streaming/state/StateStoreSuite.scala | 6 +- .../spark/sql/streaming/WatermarkSuite.scala | 191 ++++++++++++++++++ 22 files changed, 597 insertions(+), 111 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a75..a7b0e6f80c2b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public final long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 7defb9df862c..ff8576157305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -31,7 +31,8 @@ class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, val startPosition: Option[Int] = None, - val plan: Option[LogicalPlan] = None, + // Some plans fail to serialize due to bugs in scala collections. 
+ @transient val plan: Option[LogicalPlan] = None, val cause: Option[Throwable] = None) extends Exception(message, cause.orNull) with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c14f35351708..ec5f710fd987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2272,7 +2272,13 @@ object TimeWindowing extends Rule[LogicalPlan] { windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { val window = windowExpressions.head - val windowAttr = AttributeReference("window", window.dataType)() + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + val windowAttr = + AttributeReference("window", window.dataType, metadata = metadata)() val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt val windows = Seq.tabulate(maxNumOverlapping + 1) { i => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b75c1f70974..98e50d0d3c67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -148,6 +148,16 @@ trait CheckAnalysis extends PredicateHelper { } operator match { + case etw: EventTimeWatermark => + etw.eventTime.dataType match { + case s: StructType + if s.find(_.name == "end").map(_.dataType) == Some(TimestampType) => + case _: TimestampType => + case _ => + failAnalysis( + s"Event time must be defined on a window or a timestamp, but " + + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index e81370c504ab..c054fcbef36f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.{AnalysisException, InternalOutputModes} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.OutputMode @@ -55,9 +56,20 @@ object UnsupportedOperationChecker { // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => - throwError( - s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + val aggregate = aggregates.head + + // Find any attributes that are associated with an eventTime watermark. 
+ val watermarkAttributes = aggregate.groupingExpressions.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + + // We can append rows to the sink once the group is under the watermark. Without this + // watermark a group is never "finished" so we would never output anything. + if (watermarkAttributes.isEmpty) { + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + } case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => throwError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 235ae0478245..36ed9ba50372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, Codege import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -98,6 +98,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) + override def withMetadata(newMetadata: Metadata): Attribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 306a99d5a37b..127475713605 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.{Objects, UUID} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ @@ -104,6 +105,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn def withNullability(newNullability: Boolean): Attribute def withQualifier(newQualifier: Option[String]): Attribute def withName(newName: String): Attribute + def withMetadata(newMetadata: Metadata): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -292,11 +294,22 @@ case class AttributeReference( } } + override def withMetadata(newMetadata: Metadata): Attribute = { + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + } + override protected final def otherCopyArgs: Seq[AnyRef] = { exprId :: qualifier :: isGenerated :: Nil } - override def toString: 
String = s"$name#${exprId.id}$typeSuffix" + /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ + private def delaySuffix = if (metadata.contains(EventTimeWatermark.delayKey)) { + s"-T${metadata.getLong(EventTimeWatermark.delayKey)}ms" + } else { + "" + } + + override def toString: String = s"$name#${exprId.id}$typeSuffix$delaySuffix" // Since the expression id is not in the first constructor it is missing from the default // tree string. @@ -332,6 +345,8 @@ case class PrettyAttribute( override def withQualifier(newQualifier: Option[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def withMetadata(newMetadata: Metadata): Attribute = + throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala new file mode 100644 index 000000000000..4224a7997c41 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval + +object EventTimeWatermark { + /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */ + val delayKey = "spark.watermarkDelayMs" +} + +/** + * Used to mark a user specified column as holding the event time for a row. + */ +case class EventTimeWatermark( + eventTime: Attribute, + delay: CalendarInterval, + child: LogicalPlan) extends LogicalPlan { + + // Update the metadata on the eventTime column to include the desired delay. 
+ override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override val children: Seq[LogicalPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index eb2b20afc37c..af30683cc01c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { @@ -476,7 +477,7 @@ class Dataset[T] private[sql]( * `collect()`, will throw an [[AnalysisException]] when there is a streaming * source present. * - * @group basic + * @group streaming * @since 2.0.0 */ @Experimental @@ -496,8 +497,6 @@ class Dataset[T] private[sql]( /** * Returns a checkpointed version of this Dataset. * - * @param eager When true, materializes the underlying checkpointed RDD eagerly. - * * @group basic * @since 2.1.0 */ @@ -535,6 +534,41 @@ class Dataset[T] private[sql]( )(sparkSession)).as[T] } + /** + * :: Experimental :: + * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time + * before which we assume no more late data is going to arrive. + * + * Spark will use this watermark for several purposes: + * - To know when a given time window aggregation can be finalized and thus can be emitted when + * using output modes that do not allow updates. + * - To minimize the amount of state that we need to keep for on-going aggregations. + * + * The current watermark is computed by looking at the `MAX(eventTime)` seen across + * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost + * of coordinating this value across partitions, the actual watermark used is only guaranteed + * to be at least `delayThreshold` behind the actual event time. In some cases we may still + * process records that arrive more than `delayThreshold` late. + * + * @param eventTime the name of the column that contains the event time of the row. + * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest + * record that has been processed in the form of an interval + * (e.g. "1 minute" or "5 hours"). + * + * @group streaming + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + // We only accept an existing column name, not a derived column here as a watermark that is + // defined on a derived column cannot referenced elsewhere in the plan. + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { + val parsedDelay = + Option(CalendarInterval.fromString("interval " + delayThreshold)) + .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) + } + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. 
For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 190fdd84343e..2308ae8a6c61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -18,20 +18,23 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, SaveMode, Strategy} +import org.apache.spark.sql.{SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} -import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.StreamingQuery /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -224,6 +227,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object StatefulAggregationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case EventTimeWatermark(columnName, delay, child) => + EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil + case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 3c8ef1ad84c0..8b8ccf4239b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -328,8 +328,13 @@ object AggUtils { } // Note: stateId and returnAllStates are filled in later with preparation rules // in IncrementalExecution. 
- val saved = StateStoreSaveExec( - groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + val saved = + StateStoreSaveExec( + groupingAttributes, + stateId = None, + outputMode = None, + eventTimeWatermark = None, + partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index d82e54e57564..52d8dc22a2d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -104,7 +104,7 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. - new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0) + new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala new file mode 100644 index 000000000000..4c8cb069d23a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.math.max + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.AccumulatorV2 + +/** Tracks the maximum positive long seen. 
*/ +class MaxLong(protected var currentValue: Long = 0) + extends AccumulatorV2[Long, Long] { + + override def isZero: Boolean = value == 0 + override def value: Long = currentValue + override def copy(): AccumulatorV2[Long, Long] = new MaxLong(currentValue) + + override def reset(): Unit = { + currentValue = 0 + } + + override def add(v: Long): Unit = { + currentValue = max(v, value) + } + + override def merge(other: AccumulatorV2[Long, Long]): Unit = { + currentValue = max(value, other.value) + } +} + +/** + * Used to mark a column as the containing the event time for a given record. In addition to + * adding appropriate metadata to this column, this operator also tracks the maximum observed event + * time. Based on the maximum observed time and a user specified delay, we can calculate the + * `watermark` after which we assume we will no longer see late records for a particular time + * period. + */ +case class EventTimeWatermarkExec( + eventTime: Attribute, + delay: CalendarInterval, + child: SparkPlan) extends SparkPlan { + + // TODO: Use Spark SQL Metrics? + val maxEventTime = new MaxLong + sparkContext.register(maxEventTime) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output) + iter.map { row => + maxEventTime.add(getEventTime(row).getLong(0)) + row + } + } + } + + // Update the metadata on the eventTime column to include the desired delay. + override val output: Seq[Attribute] = child.output.map { a => + if (a semanticEquals eventTime) { + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .build() + + a.withMetadata(updatedMetadata) + } else { + a + } + } + + override def children: Seq[SparkPlan] = child :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 24f98b9211f1..f5c550dd6ac3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -60,7 +60,8 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria deserialized, data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId) + data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, + data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) incrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 05294df2673d..e9d072f8a98b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -32,11 +32,13 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, - val currentBatchId: Long) + val currentBatchId: Long, + val currentEventTimeWatermark: Long) extends QueryExecution(sparkSession, 
logicalPlan) { // TODO: make this always part of planning. - val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + val stateStrategy = + sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -57,17 +59,17 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, + case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), - Some(returnAllStates), + Some(outputMode), + Some(currentEventTimeWatermark), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index ad8238f189c6..7af978a9c4aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -21,12 +21,17 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution +import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + /** Used to identify the state store for a given operator. 
*/ case class OperatorStateId( @@ -92,8 +97,9 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], - returnAllStates: Option[Boolean], + stateId: Option[OperatorStateId] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { @@ -104,9 +110,9 @@ case class StateStoreSaveExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - assert(returnAllStates.nonEmpty, - "Incorrect planning in IncrementalExecution, returnAllStates have not been set") - val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -114,75 +120,95 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator) - )(saveAndReturnFunc) + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + store.commit() + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } + + // Update and output only rows being evicted from the StateStore + case Some(Append) => + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + + val watermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)).get + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + val predicate = newPredicate(evictionExpression, keyExpressions) + store.remove(predicate.eval) + + store.commit() + + numTotalStateRows += store.numKeys() + store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => + numOutputRows += 1 + removed.value.asInstanceOf[InternalRow] + } + + // Update and output modified rows from the StateStore. 
+ case Some(Update) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + numTotalStateRows += store.numKeys() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } } override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that this returns an iterator that pipelines the saving to store with downstream - * processing. - */ - private def saveAndReturnUpdated( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - numTotalStateRows += store.numKeys() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numOutputRows += 1 - numUpdatedStateRows += 1 - row - } - } - } - - /** - * Save all the rows to the state store, and return all the rows in the state store. - * Note that the saving to store is blocking; only after all the rows have been saved - * is the iterator on the update store data is generated. - */ - private def saveAndReturnAll( - store: StateStore, - iter: Iterator[InternalRow]): Iterator[InternalRow] = { - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val numOutputRows = longMetric("numOutputRows") - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 - } - store.commit() - numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => - numOutputRows += 1 - v.asInstanceOf[InternalRow] - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 57e89f85361e..3ca6feac05ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -92,6 +92,9 @@ class StreamExecution( /** The current batchId or -1 if execution has not yet been initialized. */ private var currentBatchId: Long = -1 + /** The current eventTime watermark, used to bound the lateness of data that will processed. */ + private var currentEventTimeWatermark: Long = 0 + /** All stream sources present in the query plan. 
*/ private val sources = logicalPlan.collect { case s: StreamingExecutionRelation => s.source } @@ -427,7 +430,8 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), - currentBatchId) + currentBatchId, + currentEventTimeWatermark) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -436,6 +440,25 @@ class StreamExecution( sink.addBatch(currentBatchId, nextBatch) reportNumRows(executedPlan, triggerLogicalPlan, newData) + // Update the eventTime watermark if we find one in the plan. + // TODO: Does this need to be an AttributeMap? + lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec => + logTrace(s"Maximum observed eventTime: ${e.maxEventTime.value}") + (e.maxEventTime.value / 1000) - e.delay.milliseconds() + }.headOption.foreach { newWatermark => + if (newWatermark > currentEventTimeWatermark) { + logInfo(s"Updating eventTime watermark to: $newWatermark ms") + currentEventTimeWatermark = newWatermark + } else { + logTrace(s"Event time didn't move: $newWatermark < $currentEventTimeWatermark") + } + + if (newWatermark != 0) { + streamMetrics.reportTriggerDetail(EVENT_TIME_WATERMARK, newWatermark) + } + } + awaitBatchLock.lock() try { // Wake up any threads that are waiting for the stream to progress. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index e98d1883e459..5645554a58f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -221,6 +221,7 @@ object StreamMetrics extends Logging { val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" + val EVENT_TIME_WATERMARK = "eventTimeWatermark" val START_TIMESTAMP = "timestamp.triggerStart" val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f07feaad5dc7..493fdaaec506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -109,7 +109,7 @@ private[state] class HDFSBackedStateStoreProvider( case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => // Value existed in previous version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => @@ -124,24 +124,25 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - - val keyIter = mapToUpdate.keySet().iterator() - while (keyIter.hasNext) { - val key = keyIter.next - if (condition(key)) { - keyIter.remove() + val entryIter = mapToUpdate.entrySet().iterator() + while (entryIter.hasNext) { + val entry = entryIter.next 
+ if (condition(entry.getKey)) { + val value = entry.getValue + val key = entry.getKey + entryIter.remove() Option(allUpdates.get(key)) match { case Some(ValueUpdated(_, _)) | None => // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, KeyRemoved(key)) + allUpdates.put(key, ValueRemoved(key, value)) case Some(ValueAdded(_, _)) => // Value did not exist in previous version and was added, should not appear in updates allUpdates.remove(key) - case Some(KeyRemoved(_)) => + case Some(ValueRemoved(_, _)) => // Remove already in update map, no need to change } - writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) } } } @@ -334,7 +335,7 @@ private[state] class HDFSBackedStateStoreProvider( writeUpdate(key, value) case ValueUpdated(key, value) => writeUpdate(key, value) - case KeyRemoved(key) => + case ValueRemoved(key, value) => writeRemove(key) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7132e284c28f..9bc6c0e2b933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -99,13 +99,16 @@ trait StateStoreProvider { /** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate +sealed trait StoreUpdate { + def key: UnsafeRow + def value: UnsafeRow +} case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class KeyRemoved(key: UnsafeRow) extends StoreUpdate +case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 533cd0cd2a2e..05fc7345a7da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -668,11 +668,11 @@ private[state] object StateStoreSuite { } def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { _ match { + iterator.map { case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case KeyRemoved(key) => Removed(rowToString(key)) - }}.toSet + case ValueRemoved(key, _) => Removed(rowToString(key)) + }.toSet } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala new file mode 100644 index 000000000000..3617ec0f564c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions.{count, window} + +class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("error on bad column") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("badColumn", "1 minute") + } + assert(e.getMessage contains "badColumn") + } + + test("error on wrong type") { + val inputData = MemoryStream[Int].toDF() + val e = intercept[AnalysisException] { + inputData.withWatermark("value", "1 minute") + } + assert(e.getMessage contains "value") + assert(e.getMessage contains "int") + } + + + test("watermark metric") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 15), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "5000" + }, + AddData(inputData, 25), + AssertOnLastQueryStatus { status => + status.triggerDetails.get(StreamMetrics.EVENT_TIME_WATERMARK) === "15000" + } + ) + } + + test("append-mode watermark aggregation") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 5)) + ) + } + + ignore("recovery") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + StopStream, + StartStream(), + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. 
+ StopStream, + StartStream(), + CheckAnswer((10, 5)) + ) + } + + test("dropping old data") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11, 12), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 3)), + AddData(inputData, 10), // 10 is later than 15 second watermark + CheckAnswer((10, 3)), + AddData(inputData, 25), + CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + ) + } + + test("complete mode") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + // No eviction when asked to compute complete results. + testStream(windowedAggregation, OutputMode.Complete)( + AddData(inputData, 10, 11, 12), + CheckAnswer((10, 3)), + AddData(inputData, 25), + CheckAnswer((10, 3), (25, 1)), + AddData(inputData, 25), + CheckAnswer((10, 3), (25, 2)), + AddData(inputData, 10), + CheckAnswer((10, 4), (25, 2)), + AddData(inputData, 25), + CheckAnswer((10, 4), (25, 3)) + ) + } + + test("group by on raw timestamp") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy($"eventTime") + .agg(count("*") as 'count) + .select($"eventTime".cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10), + CheckAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckAnswer(), + AddData(inputData, 25), // Evict items less than previous watermark. + CheckAnswer((10, 1)) + ) + } +} From c31def1ddcbed340bfc071d54fb3dc7945cb525a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 14 Nov 2016 21:15:39 -0800 Subject: [PATCH 211/381] [SPARK-18428][DOC] Update docs for GraphX ## What changes were proposed in this pull request? 1, Add link of `VertexRDD` and `EdgeRDD` 2, Notify in `Vertex and Edge RDDs` that not all methods are listed 3, `VertexID` -> `VertexId` ## How was this patch tested? No tests, only docs is modified Author: Zheng RuiFeng Closes #15875 from zhengruifeng/update_graphop_doc. 
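For reference, a minimal sketch of the `VertexId`-based typing that the doc changes below standardize on (illustrative only, not part of the patch; assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD

// VertexId is an alias for Long: every vertex is keyed by a unique 64-bit identifier.
val users: RDD[(VertexId, String)] = sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
val follows: RDD[Edge[Int]] = sc.parallelize(Seq(Edge(1L, 2L, 1)))
val graph: Graph[String, Int] = Graph(users, follows)
graph.vertices.count() // graph.vertices is a VertexRDD[String]
```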
--- docs/graphx-programming-guide.md | 68 ++++++++++++++++---------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 58671e6f146d..1097cf1211c1 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -11,6 +11,7 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[VertexRDD]: api/scala/index.html#org.apache.spark.graphx.VertexRDD [Edge]: api/scala/index.html#org.apache.spark.graphx.Edge [EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet [Graph]: api/scala/index.html#org.apache.spark.graphx.Graph @@ -89,7 +90,7 @@ with user defined objects attached to each vertex and edge. A directed multigra graph with potentially multiple parallel edges sharing the same source and destination vertex. The ability to support parallel edges simplifies modeling scenarios where there can be multiple relationships (e.g., co-worker and friend) between the same vertices. Each vertex is keyed by a -*unique* 64-bit long identifier (`VertexID`). GraphX does not impose any ordering constraints on +*unique* 64-bit long identifier (`VertexId`). GraphX does not impose any ordering constraints on the vertex identifiers. Similarly, edges have corresponding source and destination vertex identifiers. @@ -130,12 +131,12 @@ class Graph[VD, ED] { } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexId, VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the -`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge +`VertexRDD`[VertexRDD] and `EdgeRDD`[EdgeRDD] API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: -`RDD[(VertexID, VD)]` and `RDD[Edge[ED]]`. +`RDD[(VertexId, VD)]` and `RDD[Edge[ED]]`. ### Example Property Graph @@ -197,7 +198,7 @@ graph.edges.filter(e => e.srcId > e.dstId).count {% endhighlight %} > Note that `graph.vertices` returns an `VertexRDD[(String, String)]` which extends -> `RDD[(VertexID, (String, String))]` and so we use the scala `case` expression to deconstruct the +> `RDD[(VertexId, (String, String))]` and so we use the scala `case` expression to deconstruct the > tuple. On the other hand, `graph.edges` returns an `EdgeRDD` containing `Edge[String]` objects. 
> We could have also used the case class type constructor as in the following: > {% highlight scala %} @@ -287,7 +288,7 @@ class Graph[VD, ED] { // Change the partitioning heuristic ============================================================ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== - def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -297,18 +298,18 @@ class Graph[VD, ED] { def reverse: Graph[VD, ED] def subgraph( epred: EdgeTriplet[VD,ED] => Boolean = (x => true), - vpred: (VertexID, VD) => Boolean = ((v, d) => true)) + vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] def mask[VD2, ED2](other: Graph[VD2, ED2]): Graph[VD, ED] def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] // Join RDDs with the graph ====================================================================== - def joinVertices[U](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD): Graph[VD, ED] - def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) - (mapFunc: (VertexID, VD, Option[U]) => VD2) + def joinVertices[U](table: RDD[(VertexId, U)])(mapFunc: (VertexId, VD, U) => VD): Graph[VD, ED] + def outerJoinVertices[U, VD2](other: RDD[(VertexId, U)]) + (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] // Aggregate information about adjacent triplets ================================================= - def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] - def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] + def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] + def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] def aggregateMessages[Msg: ClassTag]( sendMsg: EdgeContext[VD, ED, Msg] => Unit, mergeMsg: (Msg, Msg) => Msg, @@ -316,15 +317,15 @@ class Graph[VD, ED] { : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( - vprog: (VertexID, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)], + vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] // Basic graph algorithms ======================================================================== def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] - def connectedComponents(): Graph[VertexID, ED] + def connectedComponents(): Graph[VertexId, ED] def triangleCount(): Graph[Int, ED] - def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] + def stronglyConnectedComponents(numIter: Int): Graph[VertexId, ED] } {% endhighlight %} @@ -481,7 +482,7 @@ original value. > is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. 
> {% highlight scala %} -val nonUniqueCosts: RDD[(VertexID, Double)] +val nonUniqueCosts: RDD[(VertexId, Double)] val uniqueCosts: VertexRDD[Double] = graph.vertices.aggregateUsingIndex(nonUnique, (a,b) => a + b) val joinedGraph = graph.joinVertices(uniqueCosts)( @@ -511,7 +512,7 @@ val degreeGraph = graph.outerJoinVertices(outDegrees) { (id, oldAttr, outDegOpt) > provide type annotation for the user defined function: > {% highlight scala %} val joinedGraph = graph.joinVertices(uniqueCosts, - (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) + (id: VertexId, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} > @@ -558,7 +559,7 @@ The user defined `mergeMsg` function takes two messages destined to the same ver yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not -receive a message are not included in the returned `VertexRDD`. +receive a message are not included in the returned `VertexRDD`[VertexRDD]. + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.LinearRegression). + {% include_example python/ml/linear_regression_with_elastic_net.py %}
@@ -519,18 +546,21 @@ function and extracting model summary statistics.
+ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala %}
+ Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html) for more details. {% include_example java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java %}
+ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example python/ml/generalized_linear_regression_example.py %} @@ -705,14 +735,23 @@ The implementation matches the result from R's survival function
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %}
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html) for more details. + {% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %}
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example python/ml/aft_survival_regression.py %}
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index adb057ba7e25..b4d6be94f5eb 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -207,14 +207,29 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
+ +Refer to the [`Estimator` Scala docs](api/scala/index.html#org.apache.spark.ml.Estimator), +the [`Transformer` Scala docs](api/scala/index.html#org.apache.spark.ml.Transformer) and +the [`Params` Scala docs](api/scala/index.html#org.apache.spark.ml.param.Params) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
+ +Refer to the [`Estimator` Java docs](api/java/org/apache/spark/ml/Estimator.html), +the [`Transformer` Java docs](api/java/org/apache/spark/ml/Transformer.html) and +the [`Params` Java docs](api/java/org/apache/spark/ml/param/Params.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
+ +Refer to the [`Estimator` Python docs](api/python/pyspark.ml.html#pyspark.ml.Estimator), +the [`Transformer` Python docs](api/python/pyspark.ml.html#pyspark.ml.Transformer) and +the [`Params` Python docs](api/python/pyspark.ml.html#pyspark.ml.param.Params) for more details on the API. + {% include_example python/ml/estimator_transformer_param_example.py %}
@@ -227,14 +242,24 @@ This example follows the simple text document `Pipeline` illustrated in the figu
+ +Refer to the [`Pipeline` Scala docs](api/scala/index.html#org.apache.spark.ml.Pipeline) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
+ + +Refer to the [`Pipeline` Java docs](api/java/org/apache/spark/ml/Pipeline.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
+ +Refer to the [`Pipeline` Python docs](api/python/pyspark.ml.html#pyspark.ml.Pipeline) for more details on the API. + {% include_example python/ml/pipeline_example.py %}
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index e4b070331db4..a135adc4334c 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -75,15 +75,23 @@ However, it is also a well-established method for choosing parameters which is m
+ +Refer to the [`CrossValidator` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
+ +Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
+Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API. + {% include_example python/ml/cross_validator.py %}
@@ -107,14 +115,23 @@ Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using
+ +Refer to the [`TrainValidationSplit` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
+ +Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
+ +Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API. + {% include_example python/ml/train_validation_split.py %}
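For readers who want to see the tuning API end to end, here is a rough PySpark sketch of `TrainValidationSplit` over a small synthetic regression dataset; the grid values and data are made up for illustration, and the bundled `train_validation_split.py` example referenced above is the canonical one:

```python
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

spark = SparkSession.builder.appName("TrainValidationSplitSketch").getOrCreate()

# Synthetic data: label is a simple function of the two features.
data = spark.createDataFrame(
    [(float(i) + (i % 3), Vectors.dense(float(i), float(i % 3))) for i in range(100)],
    ["label", "features"])
train, test = data.randomSplit([0.8, 0.2], seed=42)

lr = LinearRegression(maxIter=10)

# Grid of candidate hyperparameters to try.
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.1, 0.01])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .build())

# For each candidate, 75% of `train` is used for fitting, 25% for validation.
tvs = TrainValidationSplit(estimator=lr,
                           estimatorParamMaps=paramGrid,
                           evaluator=RegressionEvaluator(),
                           trainRatio=0.75)

model = tvs.fit(train)
model.transform(test).select("features", "label", "prediction").show(5)

spark.stop()
```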
From 7569cf6cb85bda7d0e76d3e75e286d4796e77e08 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Wed, 16 Nov 2016 11:59:00 +0000 Subject: [PATCH 233/381] [SPARK-18420][BUILD] Fix the errors caused by lint check in Java ## What changes were proposed in this pull request? Small fix, fix the errors caused by lint check in Java - Clear unused objects and `UnusedImports`. - Add comments around the method `finalize` of `NioBufferedFileInputStream`to turn off checkstyle. - Cut the line which is longer than 100 characters into two lines. ## How was this patch tested? Travis CI. ``` $ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install $ dev/lint-java ``` Before: ``` Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/network/util/TransportConf.java:[21,8] (imports) UnusedImports: Unused import - org.apache.commons.crypto.cipher.CryptoCipherFactory. [ERROR] src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java:[516,5] (modifier) RedundantModifier: Redundant 'public' modifier. [ERROR] src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java:[133] (coding) NoFinalizer: Avoid using finalizer method. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java:[71] (sizes) LineLength: Line is longer than 100 characters (found 113). [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java:[112] (sizes) LineLength: Line is longer than 100 characters (found 110). [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[31,17] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR]src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java:[64] (sizes) LineLength: Line is longer than 100 characters (found 103). [ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.ml.linalg.Vectors. [ERROR] src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java:[51] (regexp) RegexpSingleline: No trailing whitespace allowed. ``` After: ``` $ build/mvn -T 4 -q -DskipTests -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install $ dev/lint-java Using `mvn` from path: /home/travis/build/ConeyLiu/spark/build/apache-maven-3.3.9/bin/mvn Checkstyle checks passed. ``` Author: Xianyang Liu Closes #15865 from ConeyLiu/master. 
--- .../apache/spark/network/util/TransportConf.java | 1 - .../apache/spark/network/sasl/SparkSaslSuite.java | 2 +- .../spark/io/NioBufferedFileInputStream.java | 2 ++ dev/checkstyle.xml | 15 +++++++++++++++ .../spark/examples/ml/JavaInteractionExample.java | 3 +-- ...vaLogisticRegressionWithElasticNetExample.java | 4 ++-- .../sql/catalyst/expressions/UnsafeArrayData.java | 3 ++- .../sql/catalyst/expressions/UnsafeMapData.java | 3 ++- .../sql/catalyst/expressions/HiveHasherSuite.java | 1 - 9 files changed, 25 insertions(+), 9 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index d0d072849d38..012bb098f6fc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -18,7 +18,6 @@ package org.apache.spark.network.util; import com.google.common.primitives.Ints; -import org.apache.commons.crypto.cipher.CryptoCipherFactory; /** * A central location that tracks all the settings we expose to users. diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 4e6146cf070d..ef2ab34b2277 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -513,7 +513,7 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd boolean foundEncryptionHandler; String encryptHandlerName; - public EncryptionCheckerBootstrap(String encryptHandlerName) { + EncryptionCheckerBootstrap(String encryptHandlerName) { this.encryptHandlerName = encryptHandlerName; } diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263..ea5f1a9abf69 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 3de6aa91dcd5..92c5251c8503 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -52,6 +52,20 @@ + + + + + + + @@ -168,5 +182,6 @@ + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java index 4213c05703cc..3684a87e22e7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java @@ -19,7 +19,6 @@ import org.apache.spark.ml.feature.Interaction; import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; @@ -48,7 +47,7 @@ public static void main(String[] args) { RowFactory.create(5, 9, 2, 7, 10, 7, 3), RowFactory.create(6, 1, 1, 4, 2, 8, 4) ); - + StructType schema 
= new StructType(new StructField[]{ new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()), new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()), diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index b8fb5972ea41..4cdec21d2302 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -60,8 +60,8 @@ public static void main(String[] args) { LogisticRegressionModel mlrModel = mlr.fit(training); // Print the coefficients and intercepts for logistic regression with multinomial family - System.out.println("Multinomial coefficients: " - + lrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); + System.out.println("Multinomial coefficients: " + lrModel.coefficientMatrix() + + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 86523c147401..e8c33871f97b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -109,7 +109,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the number of elements from the first 8 bytes. final long numElements = Platform.getLong(baseObject, baseOffset); assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; - assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; + assert numElements <= Integer.MAX_VALUE : + "numElements (" + numElements + ") should <= Integer.MAX_VALUE"; this.numElements = (int)numElements; this.baseObject = baseObject; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 35029f5a50e3..f17441dfccb6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -68,7 +68,8 @@ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { // Read the numBytes of key array from the first 8 bytes. 
final long keyArraySize = Platform.getLong(baseObject, baseOffset); assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0"; - assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; + assert keyArraySize <= Integer.MAX_VALUE : + "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE"; final int valueArraySize = sizeInBytes - (int)keyArraySize - 8; assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0"; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 67a5eb0c7fe8..b67c6f3e6e85 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -28,7 +28,6 @@ import java.util.Set; public class HiveHasherSuite { - private final static HiveHasher hasher = new HiveHasher(); @Test public void testKnownIntegerInputs() { From 608ecc512b759514c75a1b475582f237ed569f10 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 16 Nov 2016 08:25:15 -0800 Subject: [PATCH 234/381] [SPARK-18415][SQL] Weird Plan Output when CTE used in RunnableCommand ### What changes were proposed in this pull request? Currently, when CTE is used in RunnableCommand, the Analyzer does not replace the logical node `With`. The child plan of RunnableCommand is not resolved. Thus, the output of the `With` plan node looks very confusing. For example, ``` sql( """ |CREATE VIEW cte_view AS |WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) |SELECT n FROM w """.stripMargin).explain() ``` The output is like ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- 'With [(w,SubqueryAlias w +- Project [1 AS n#16] +- OneRowRelation$ ), (cte1,'SubqueryAlias cte1 +- 'Project [unresolvedalias(2, None)] +- OneRowRelation$ ), (cte2,'SubqueryAlias cte2 +- 'Project [unresolvedalias(3, None)] +- OneRowRelation$ )] +- 'Project ['n] +- 'UnresolvedRelation `w` ``` After the fix, the output is as shown below. ``` ExecutedCommand +- CreateViewCommand `cte_view`, WITH w AS (SELECT 1 AS n), cte1 (select 2), cte2 as (select 3) SELECT n FROM w, false, false, PersistedView +- CTE [w, cte1, cte2] : :- SubqueryAlias w : : +- Project [1 AS n#16] : : +- OneRowRelation$ : :- 'SubqueryAlias cte1 : : +- 'Project [unresolvedalias(2, None)] : : +- OneRowRelation$ : +- 'SubqueryAlias cte2 : +- 'Project [unresolvedalias(3, None)] : +- OneRowRelation$ +- 'Project ['n] +- 'UnresolvedRelation `w` ``` BTW, this PR also fixes the output of the view type. ### How was this patch tested? Manual Author: gatorsmile Closes #15854 from gatorsmile/cteName. 
--- .../catalyst/plans/logical/basicLogicalOperators.scala | 8 ++++++++ .../org/apache/spark/sql/execution/command/views.scala | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4dcc2885536e..4e333d57f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * When planning take() or collect() operations, this special node that is inserted at the top of @@ -404,6 +405,13 @@ case class InsertIntoTable( */ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def simpleString: String = { + val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]") + s"CTE $cteAliases" + } + + override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2) } case class WithWindowDefinition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 30472ec45ce4..154141bf83c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -33,7 +33,9 @@ import org.apache.spark.sql.types.MetadataBuilder * ViewType is used to specify the expected view type when we want to create or replace a view in * [[CreateViewCommand]]. */ -sealed trait ViewType +sealed trait ViewType { + override def toString: String = getClass.getSimpleName.stripSuffix("$") +} /** * LocalTempView means session-scoped local temporary views. Its lifetime is the lifetime of the From 0048ce7ce64b02cbb6a1c4a2963a0b1b9541047e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Nov 2016 10:00:59 -0800 Subject: [PATCH 235/381] [SPARK-18459][SPARK-18460][STRUCTUREDSTREAMING] Rename triggerId to batchId and add triggerDetails to json in StreamingQueryStatus ## What changes were proposed in this pull request? SPARK-18459: triggerId seems like a number that should be increasing with each trigger, whether or not there is data in it. However, actually, triggerId increases only where there is a batch of data in a trigger. So its better to rename it to batchId. SPARK-18460: triggerDetails was missing from json representation. Fixed it. ## How was this patch tested? Updated existing unit tests. Author: Tathagata Das Closes #15895 from tdas/SPARK-18459. 
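As a rough sketch of how the rename surfaces to a PySpark user (the `query` object here is hypothetical and assumed to be an already-running `StreamingQuery`; the key names come from the updated doctest in the diff below), code that previously read `triggerId` would now read `batchId` from the trigger details:

```python
# `query` is assumed to be an already-started StreamingQuery.
status = query.status

# After this change the per-trigger map is keyed by 'batchId', not 'triggerId'.
details = status.triggerDetails
print("last batch id: %s" % details.get("batchId"))
print("trigger active: %s" % details.get("isTriggerActive"))
```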
--- python/pyspark/sql/streaming.py | 6 ++--- .../execution/streaming/StreamMetrics.scala | 8 +++---- .../sql/streaming/StreamingQueryStatus.scala | 4 ++-- .../streaming/StreamMetricsSuite.scala | 8 +++---- .../StreamingQueryListenerSuite.scala | 4 ++-- .../streaming/StreamingQueryStatusSuite.scala | 22 +++++++++++++++++-- 6 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index f326f1623269..0e4589be976e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -212,12 +212,12 @@ def __str__(self): Processing rate 23.5 rows/sec Latency: 345.0 ms Trigger details: + batchId: 5 isDataPresentInTrigger: true isTriggerActive: true latency.getBatch.total: 20 latency.getOffset.total: 10 numRows.input.total: 100 - triggerId: 5 Source statuses [1 source]: Source 1 - MySource1 Available offset: 0 @@ -341,8 +341,8 @@ def triggerDetails(self): If no trigger is currently active, then it will have details of the last completed trigger. >>> sqs.triggerDetails - {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', - u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + {u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'batchId': u'5', u'latency.getOffset.total': u'10', u'isDataPresentInTrigger': u'true'} """ return self._jsqs.triggerDetails() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala index 5645554a58f6..942e6ed8944b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -78,13 +78,13 @@ class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceNam // =========== Setter methods =========== - def reportTriggerStarted(triggerId: Long): Unit = synchronized { + def reportTriggerStarted(batchId: Long): Unit = synchronized { numInputRows.clear() triggerDetails.clear() sourceTriggerDetails.values.foreach(_.clear()) - reportTriggerDetail(TRIGGER_ID, triggerId) - sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(BATCH_ID, batchId) + sources.foreach(s => reportSourceTriggerDetail(s, BATCH_ID, batchId)) reportTriggerDetail(IS_TRIGGER_ACTIVE, true) currentTriggerStartTimestamp = triggerClock.getTimeMillis() reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) @@ -217,7 +217,7 @@ object StreamMetrics extends Logging { } - val TRIGGER_ID = "triggerId" + val BATCH_ID = "batchId" val IS_TRIGGER_ACTIVE = "isTriggerActive" val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" val STATUS_MESSAGE = "statusMessage" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 99c7729d0235..ba732ff7fc2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -102,7 +102,7 @@ class StreamingQueryStatus private( ("inputRate" -> JDouble(inputRate)) ~ ("processingRate" -> JDouble(processingRate)) ~ ("latency" -> latency.map(JDouble).getOrElse(JNothing)) ~ - ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + 
("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) ~ ("sourceStatuses" -> JArray(sourceStatuses.map(_.jsonValue).toList)) ~ ("sinkStatus" -> sinkStatus.jsonValue) } @@ -151,7 +151,7 @@ private[sql] object StreamingQueryStatus { desc = "MySink", offsetDesc = OffsetSeq(Some(LongOffset(1)) :: None :: Nil).toString), triggerDetails = Map( - TRIGGER_ID -> "5", + BATCH_ID -> "5", IS_TRIGGER_ACTIVE -> "true", IS_DATA_PRESENT_IN_TRIGGER -> "true", GET_OFFSET_LATENCY -> "10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala index 938423db6474..38c4ece43977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -50,10 +50,10 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 0.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "true", START_TIMESTAMP -> "0", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", "key2" -> "value2")) + Map(BATCH_ID -> "1", "key2" -> "value2")) // Finishing the trigger should calculate the rates, except input rate which needs // to have another trigger interval @@ -66,11 +66,11 @@ class StreamMetricsSuite extends SparkFunSuite { assert(sm.currentSourceProcessingRate(source) === 100.0) assert(sm.currentLatency() === None) assert(sm.currentTriggerDetails() === - Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + Map(BATCH_ID -> "1", IS_TRIGGER_ACTIVE -> "false", START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", NUM_INPUT_ROWS -> "100", "key" -> "value")) assert(sm.currentSourceTriggerDetails(source) === - Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + Map(BATCH_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) // After another trigger starts, the rates and latencies should not change until // new rows are reported diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index cebb32a0a56c..98f3bec7080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -84,7 +84,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AssertOnLastQueryStatus { status: StreamingQueryStatus => // Check the correctness of the trigger info of the last completed batch reported by // onQueryProgress - assert(status.triggerDetails.containsKey("triggerId")) + assert(status.triggerDetails.containsKey("batchId")) assert(status.triggerDetails.get("isTriggerActive") === "false") assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") @@ -104,7 +104,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") assert(status.sourceStatuses.length === 1) - assert(status.sourceStatuses(0).triggerDetails.containsKey("triggerId")) + assert(status.sourceStatuses(0).triggerDetails.containsKey("batchId")) 
assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") === "100") assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200") assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala index 6af19fb0c232..50a7d92ede9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala @@ -48,12 +48,12 @@ class StreamingQueryStatusSuite extends SparkFunSuite { | Processing rate 23.5 rows/sec | Latency: 345.0 ms | Trigger details: + | batchId: 5 | isDataPresentInTrigger: true | isTriggerActive: true | latency.getBatch.total: 20 | latency.getOffset.total: 10 | numRows.input.total: 100 - | triggerId: 5 | Source statuses [1 source]: | Source 1 - MySource1 | Available offset: 0 @@ -72,7 +72,11 @@ class StreamingQueryStatusSuite extends SparkFunSuite { test("json") { assert(StreamingQueryStatus.testStatus.json === """ - |{"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5, + |{"name":"query","id":1,"timestamp":123,"inputRate":15.5,"processingRate":23.5, + |"latency":345.0,"triggerDetails":{"latency.getBatch.total":"20", + |"numRows.input.total":"100","isTriggerActive":"true","batchId":"5", + |"latency.getOffset.total":"10","isDataPresentInTrigger":"true"}, + |"sourceStatuses":[{"description":"MySource1","offsetDesc":"0","inputRate":15.5, |"processingRate":23.5,"triggerDetails":{"numRows.input.source":"100", |"latency.getOffset.source":"10","latency.getBatch.source":"20"}}], |"sinkStatus":{"description":"MySink","offsetDesc":"[1, -]"}} @@ -84,6 +88,20 @@ class StreamingQueryStatusSuite extends SparkFunSuite { StreamingQueryStatus.testStatus.prettyJson === """ |{ + | "name" : "query", + | "id" : 1, + | "timestamp" : 123, + | "inputRate" : 15.5, + | "processingRate" : 23.5, + | "latency" : 345.0, + | "triggerDetails" : { + | "latency.getBatch.total" : "20", + | "numRows.input.total" : "100", + | "isTriggerActive" : "true", + | "batchId" : "5", + | "latency.getOffset.total" : "10", + | "isDataPresentInTrigger" : "true" + | }, | "sourceStatuses" : [ { | "description" : "MySource1", | "offsetDesc" : "0", From bb6cdfd9a6a6b6c91aada7c3174436146045ed1e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Nov 2016 11:03:10 -0800 Subject: [PATCH 236/381] [SPARK-18461][DOCS][STRUCTUREDSTREAMING] Added more information about monitoring streaming queries ## What changes were proposed in this pull request? screen shot 2016-11-15 at 6 27 32 pm screen shot 2016-11-15 at 6 27 45 pm Author: Tathagata Das Closes #15897 from tdas/SPARK-18461. --- .../structured-streaming-programming-guide.md | 182 +++++++++++++++++- 1 file changed, 179 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d2545584ae3b..77b66b3b3a49 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1087,9 +1087,185 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` -([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), -which will give you regular callback-based updates when queries are started and terminated. + +## Monitoring Streaming Queries +There are two ways you can monitor queries. You can directly get the current status +of an active query using `streamingQuery.status`, which will return a `StreamingQueryStatus` object +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryStatus) docs) +that has all the details like current ingestion rates, processing rates, average latency, +details of the currently active trigger, etc. + +
+
+ +{% highlight scala %} +val query: StreamingQuery = ... + +println(query.status) + +/* Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +*/ +{% endhighlight %} + +
+
+ +{% highlight java %} +StreamingQuery query = ... + +System.out.println(query.status()); + +/* Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +*/ +{% endhighlight %} + +
+
+ +{% highlight python %} +query = ...  # a StreamingQuery + +print(query.status) + +''' +Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +''' +{% endhighlight %} + +
+
+ + +You can also asynchronously monitor all queries associated with a +`SparkSession` by attaching a `StreamingQueryListener` +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs). +Once you attach your custom `StreamingQueryListener` object with +`sparkSession.streams.addListener()`, you will get callbacks when a query is started and +stopped and when there is progress made in an active query. Here is an example: + +
+
+ +{% highlight scala %} +val spark: SparkSession = ... + +spark.streams.addListener(new StreamingQueryListener() { + + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + println("Query started: " + queryStarted.queryStatus.name) + } + override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = { + println("Query terminated: " + queryTerminated.queryStatus.name) + } + override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { + println("Query made progress: " + queryProgress.queryStatus) + } +}) +{% endhighlight %} + +
+
+ +{% highlight java %} +SparkSession spark = ... + +spark.streams().addListener(new StreamingQueryListener() { + + @Override public void onQueryStarted(QueryStartedEvent queryStarted) { + System.out.println("Query started: " + queryStarted.queryStatus().name()); + } + @Override public void onQueryTerminated(QueryTerminatedEvent queryTerminated) { + System.out.println("Query terminated: " + queryTerminated.queryStatus().name()); + } + @Override public void onQueryProgress(QueryProgressEvent queryProgress) { + System.out.println("Query made progress: " + queryProgress.queryStatus()); + } +}); +{% endhighlight %} + +
+
+{% highlight bash %} +Not available in Python. +{% endhighlight %} + +
+
## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). From a36a76ac43c36a3b897a748bd9f138b629dbc684 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 16 Nov 2016 14:22:15 -0800 Subject: [PATCH 237/381] [SPARK-1267][SPARK-18129] Allow PySpark to be pip installed ## What changes were proposed in this pull request? This PR aims to provide a pip installable PySpark package. This does a bunch of work to copy the jars over and package them with the Python code (to prevent challenges from trying to use different versions of the Python code with different versions of the JAR). It does not currently publish to PyPI but that is the natural follow up (SPARK-18129). Done: - pip installable on conda [manual tested] - setup.py installed on a non-pip managed system (RHEL) with YARN [manual tested] - Automated testing of this (virtualenv) - packaging and signing with release-build* Possible follow up work: - release-build update to publish to PyPI (SPARK-18128) - figure out who owns the pyspark package name on prod PyPI (is it someone with in the project or should we ask PyPI or should we choose a different name to publish with like ApachePySpark?) - Windows support and or testing ( SPARK-18136 ) - investigate details of wheel caching and see if we can avoid cleaning the wheel cache during our test - consider how we want to number our dev/snapshot versions Explicitly out of scope: - Using pip installed PySpark to start a standalone cluster - Using pip installed PySpark for non-Python Spark programs *I've done some work to test release-build locally but as a non-committer I've just done local testing. ## How was this patch tested? Automated testing with virtualenv, manual testing with conda, a system wide install, and YARN integration. release-build changes tested locally as a non-committer (no testing of upload artifacts to Apache staging websites) Author: Holden Karau Author: Juliet Hougland Author: Juliet Hougland Closes #15659 from holdenk/SPARK-1267-pip-install-pyspark. 
--- .gitignore | 2 + bin/beeline | 2 +- bin/find-spark-home | 41 ++++ bin/load-spark-env.sh | 2 +- bin/pyspark | 6 +- bin/run-example | 2 +- bin/spark-class | 6 +- bin/spark-shell | 4 +- bin/spark-sql | 2 +- bin/spark-submit | 2 +- bin/sparkR | 2 +- dev/create-release/release-build.sh | 26 ++- dev/create-release/release-tag.sh | 11 +- dev/lint-python | 4 +- dev/make-distribution.sh | 16 +- dev/pip-sanity-check.py | 36 +++ dev/run-pip-tests | 115 ++++++++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + docs/building-spark.md | 8 + docs/index.md | 4 +- .../spark/launcher/CommandBuilderUtils.java | 2 +- python/MANIFEST.in | 22 ++ python/README.md | 32 +++ python/pyspark/__init__.py | 1 + python/pyspark/find_spark_home.py | 74 +++++++ python/pyspark/java_gateway.py | 3 +- python/pyspark/version.py | 19 ++ python/setup.cfg | 22 ++ python/setup.py | 209 ++++++++++++++++++ 31 files changed, 660 insertions(+), 24 deletions(-) create mode 100755 bin/find-spark-home create mode 100644 dev/pip-sanity-check.py create mode 100755 dev/run-pip-tests create mode 100644 python/MANIFEST.in create mode 100644 python/README.md create mode 100755 python/pyspark/find_spark_home.py create mode 100644 python/pyspark/version.py create mode 100644 python/setup.cfg create mode 100644 python/setup.py diff --git a/.gitignore b/.gitignore index 39d17e1793f7..5634a434db0c 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/bin/beeline b/bin/beeline index 1627626941a7..058534699e44 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 000000000000..fa78407d4175 --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. 
+ # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. + export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3..8a2f709960a2 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a4432..98387c2ec5b8 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c412026..4ba5399311d3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f..77ea40cc3794 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. 
-if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880e..421f36cac3d4 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51d..b08b944ebd31 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b..4e9d3614e637 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173..29ab10df8ab6 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 81f0d63054e2..1dbfa3b6e361 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -162,14 +162,35 @@ if [[ "$1" == "package" ]]; then export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" + # Write out the NAME and VERSION to PySpark version info we rewrite the - into a . and SNAPSHOT + # to dev0 to be closer to PEP440. We use the NAME as a "local version". + PYSPARK_VERSION=`echo "$SPARK_VERSION+$NAME" | sed -r "s/-/./" | sed -r "s/SNAPSHOT/dev0/"` + echo "__version__='$PYSPARK_VERSION'" > python/pyspark/version.py + # Get maven home set by MVN MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + echo "Creating distribution" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. - cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + echo "Copying and signing python distribution" + PYTHON_DIST_NAME=pyspark-$PYSPARK_VERSION.tar.gz + cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_DIST_NAME . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output $PYTHON_DIST_NAME.asc \ + --detach-sig $PYTHON_DIST_NAME + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.sha + + echo "Copying and signing regular binary distribution" + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz @@ -208,6 +229,7 @@ if [[ "$1" == "package" ]]; then # Re-upload a second time and leave the files in the timestamped upload directory: LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' + LFTP mput -O $dest_dir 'pyspark-*' exit 0 fi diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b7e5100ca740..370a62ce15bc 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -65,6 +65,7 @@ sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION # Set the release version in docs sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp3" 's/__version__ = .*$/__version__ = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" @@ -74,12 +75,16 @@ git tag $RELEASE_TAG $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs # Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` -sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +sed -i".tmp4" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +# Write out the R_NEXT_VERSION to PySpark version info we use dev0 instead of SNAPSHOT to be closer +# to PEP440. +sed -i".tmp5" 's/__version__ = .*$/__version__ = "'"$R_NEXT_VERSION.dev0"'"/' python/pyspark/version.py + # Update docs with next version -sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp6" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml # Use R version for short version -sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/lint-python b/dev/lint-python index 63487043a50b..3f878c2dad6b 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,9 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" +# TODO: fix pep8 errors with the rest of the Python scripts under dev +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/*.py ./dev/run-tests-jenkins.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/pip-sanity-check.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 9be4fdfa51c9..49b46fbc3fb2 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -33,6 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`/.."; pwd)" DISTDIR="$SPARK_HOME/dist" MAKE_TGZ=false +MAKE_PIP=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -40,7 +41,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for making 
binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ]" + cl_options="[--name] [--tgz] [--pip] [--mvn ]" echo "make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -67,6 +68,9 @@ while (( "$#" )); do --tgz) MAKE_TGZ=true ;; + --pip) + MAKE_PIP=true + ;; --mvn) MVN="$2" shift @@ -201,6 +205,16 @@ fi # Copy data files cp -r "$SPARK_HOME/data" "$DISTDIR" +# Make pip package +if [ "$MAKE_PIP" == "true" ]; then + echo "Building python distribution package" + cd $SPARK_HOME/python + python setup.py sdist + cd .. +else + echo "Skipping creating pip installable PySpark" +fi + # Copy other things mkdir "$DISTDIR"/conf cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf diff --git a/dev/pip-sanity-check.py b/dev/pip-sanity-check.py new file mode 100644 index 000000000000..430c2ab52766 --- /dev/null +++ b/dev/pip-sanity-check.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +import sys + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("PipSanityCheck")\ + .getOrCreate() + sc = spark.sparkContext + rdd = sc.parallelize(range(100), 10) + value = rdd.reduce(lambda x, y: x + y) + if (value != 4950): + print("Value {0} did not match expected value.".format(value), file=sys.stderr) + sys.exit(-1) + print("Successfully ran pip sanity check") + + spark.stop() diff --git a/dev/run-pip-tests b/dev/run-pip-tests new file mode 100755 index 000000000000..e1da18e60bb3 --- /dev/null +++ b/dev/run-pip-tests @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stop on error +set -e +# Set nullglob for when we are checking existence based on globs +shopt -s nullglob + +FWDIR="$(cd "$(dirname "$0")"/..; pwd)" +cd "$FWDIR" + +echo "Constucting virtual env for testing" +VIRTUALENV_BASE=$(mktemp -d) + +# Clean up the virtual env enviroment used if we created one. 
+function delete_virtualenv() { + echo "Cleaning up temporary directory - $VIRTUALENV_BASE" + rm -rf "$VIRTUALENV_BASE" +} +trap delete_virtualenv EXIT + +# Some systems don't have pip or virtualenv - in those cases our tests won't work. +if ! hash virtualenv 2>/dev/null; then + echo "Missing virtualenv skipping pip installability tests." + exit 0 +fi +if ! hash pip 2>/dev/null; then + echo "Missing pip, skipping pip installability tests." + exit 0 +fi + +# Figure out which Python execs we should test pip installation with +PYTHON_EXECS=() +if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') +elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') +fi +if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') +fi + +# Determine which version of PySpark we are building for archive name +PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" +# The pip install options we use for all the pip commands +PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " +# Test both regular user and edit/dev install modes. +PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" + "pip install $PIP_OPTIONS -e python/") + +for python in "${PYTHON_EXECS[@]}"; do + for install_command in "${PIP_COMMANDS[@]}"; do + echo "Testing pip installation with python $python" + # Create a temp directory for us to work in and save its name to a file for cleanup + echo "Using $VIRTUALENV_BASE for virtualenv" + VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python + rm -rf "$VIRTUALENV_PATH" + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + # Upgrade pip + pip install --upgrade pip + + echo "Creating pip installable source dist" + cd "$FWDIR"/python + $python setup.py sdist + + + echo "Installing dist into virtual env" + cd dist + # Verify that the dist directory only contains one thing to install + sdists=(*.tar.gz) + if [ ${#sdists[@]} -ne 1 ]; then + echo "Unexpected number of targets found in dist directory - please cleanup existing sdists first." 
+ exit -1 + fi + # Do the actual installation + cd "$FWDIR" + $install_command + + cd / + + echo "Run basic sanity check on pip installed version with spark-submit" + spark-submit "$FWDIR"/dev/pip-sanity-check.py + echo "Run basic sanity check with import based" + python "$FWDIR"/dev/pip-sanity-check.py + echo "Run the tests for context.py" + python "$FWDIR"/python/pyspark/context.py + + cd "$FWDIR" + + done +done + +exit 0 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index a48d918f9dc1..1d1e72faccf2 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -128,6 +128,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( tests_timeout) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5d661f5f1a1c..ab285ac96af7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -432,6 +432,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_python_packaging_tests(): + set_title_and_block("Running PySpark packaging tests", "BLOCK_PYSPARK_PIP_TESTS") + command = [os.path.join(SPARK_HOME, "dev", "run-pip-tests")] + run_cmd(command) + + def run_build_tests(): set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) @@ -583,6 +589,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: run_python_tests(modules_with_python_tests, opts.parallelism) + run_python_packaging_tests() if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 89015f8c4fb9..38f25da41f77 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -33,5 +33,6 @@ "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, "BLOCK_BUILD_TESTS": 22, + "BLOCK_PYSPARK_PIP_TESTS": 23, "BLOCK_TIMEOUT": 124 } diff --git a/docs/building-spark.md b/docs/building-spark.md index 2b404bd3e116..88da0cc9c3bb 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -265,6 +265,14 @@ or Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. +## PySpark pip installable + +If you are building Spark for use in a Python environment and you wish to pip install it, you will first need to build the Spark JARs as described above. Then you can construct an sdist package suitable for setup.py and pip installable package. + + cd python; python setup.py sdist + +**Note:** Due to packaging requirements you can not directly pip install from the Python directory, rather you must first build the sdist package as described above. + ## PySpark Tests with Maven If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. 
diff --git a/docs/index.md b/docs/index.md index fe51439ae08d..39de11de854a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,7 +14,9 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version -[by augmenting Spark's classpath](hadoop-provided.html). +[by augmenting Spark's classpath](hadoop-provided.html). +Scala and Java users can include Spark in their projects using its maven cooridnates and in the future Python users can also install Spark from PyPI. + If you'd like to build Spark from source, visit [Building Spark](building-spark.html). diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 62a22008d0d5..250b2a882feb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -357,7 +357,7 @@ static int javaMajorVersion(String javaVersion) { static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { // TODO: change to the correct directory once the assembly build is changed. File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { + if (new File(sparkHome, "jars").isDirectory()) { libdir = new File(sparkHome, "jars"); checkState(!failIfNotFound || libdir.isDirectory(), "Library directory '%s' does not exist.", diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 000000000000..bbcce1baa439 --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +global-exclude *.py[cod] __pycache__ .DS_Store +recursive-include deps/jars *.jar +graft deps/bin +recursive-include deps/examples *.py +recursive-include lib *.zip +include README.md diff --git a/python/README.md b/python/README.md new file mode 100644 index 000000000000..0a5c8010b848 --- /dev/null +++ b/python/README.md @@ -0,0 +1,32 @@ +# Apache Spark + +Spark is a fast and general cluster computing system for Big Data. It provides +high-level APIs in Scala, Java, Python, and R, and an optimized engine that +supports general computation graphs for data analysis. It also supports a +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, +and Spark Streaming for stream processing. 
+ + + +## Online Documentation + +You can find the latest Spark documentation, including a programming +guide, on the [project web page](http://spark.apache.org/documentation.html). + + +## Python Packaging + +This README file only contains basic information related to pip installed PySpark. +This packaging is currently experimental and may change in future versions (although we will do our best to keep compatibility). +Using PySpark requires the Spark JARs, and if you are building this from source, please see the build instructions at +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). + +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). + + +**NOTE:** If you are using this with a Spark standalone cluster, you must ensure that the version (including minor version) matches, or you may experience odd errors. + +## Python Requirements + +At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index ec1687415a7f..5f93586a48a5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -50,6 +50,7 @@ from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler +from pyspark.version import __version__ def since(version): diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py new file mode 100755 index 000000000000..212a618b767a --- /dev/null +++ b/python/pyspark/find_spark_home.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script attempts to determine the correct setting for SPARK_HOME given +# that Spark may have been installed on the system with pip. + +from __future__ import print_function +import os +import sys + + +def _find_spark_home(): + """Find the SPARK_HOME.""" + # If the environment has SPARK_HOME set, trust it.
+ if "SPARK_HOME" in os.environ: + return os.environ["SPARK_HOME"] + + def is_spark_home(path): + """Takes a path and returns true if the provided path could be a reasonable SPARK_HOME""" + return (os.path.isfile(os.path.join(path, "bin/spark-submit")) and + (os.path.isdir(os.path.join(path, "jars")) or + os.path.isdir(os.path.join(path, "assembly")))) + + paths = ["../", os.path.dirname(os.path.realpath(__file__))] + + # Add the path of the PySpark module if it exists + if sys.version < "3": + import imp + try: + module_home = imp.find_module("pyspark")[1] + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + else: + from importlib.util import find_spec + try: + module_home = os.path.dirname(find_spec("pyspark").origin) + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + pass + + # Normalize the paths + paths = [os.path.abspath(p) for p in paths] + + try: + return next(path for path in paths if is_spark_home(path)) + except StopIteration: + print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) + exit(-1) + +if __name__ == "__main__": + print(_find_spark_home()) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c1cf843d8438..3c783ae541a1 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,6 +29,7 @@ xrange = range from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int @@ -41,7 +42,7 @@ def launch_gateway(conf=None): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: - SPARK_HOME = os.environ["SPARK_HOME"] + SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/version.py b/python/pyspark/version.py new file mode 100644 index 000000000000..08a301695fda --- /dev/null +++ b/python/pyspark/version.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "2.1.0.dev0" diff --git a/python/setup.cfg b/python/setup.cfg new file mode 100644 index 000000000000..d100b932bbaf --- /dev/null +++ b/python/setup.cfg @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[bdist_wheel] +universal = 1 + +[metadata] +description-file = README.md diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 000000000000..625aea04073f --- /dev/null +++ b/python/setup.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import glob +import os +import sys +from setuptools import setup, find_packages +from shutil import copyfile, copytree, rmtree + +if sys.version_info < (2, 7): + print("Python versions prior to 2.7 are not supported for pip installed PySpark.", + file=sys.stderr) + exit(-1) + +try: + exec(open('pyspark/version.py').read()) +except IOError: + print("Failed to load PySpark version file for packaging. You must be in Spark's python dir.", + file=sys.stderr) + sys.exit(-1) +VERSION = __version__ +# A temporary path so we can access above the Python project root and fetch scripts and jars we need +TEMP_PATH = "deps" +SPARK_HOME = os.path.abspath("../") + +# Provide guidance about how to use setup.py +incorrect_invocation_message = """ +If you are installing pyspark from spark source, you must first build Spark and +run sdist. + + To build Spark with maven you can run: + ./build/mvn -DskipTests clean package + Building the source dist is done in the Python directory: + cd python + python setup.py sdist + pip install dist/*.tar.gz""" + +# Figure out where the jars are we need to package with PySpark. 
+JARS_PATH = glob.glob(os.path.join(SPARK_HOME, "assembly/target/scala-*/jars/")) + +if len(JARS_PATH) == 1: + JARS_PATH = JARS_PATH[0] +elif (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1): + # Release mode puts the jars in a jars directory + JARS_PATH = os.path.join(SPARK_HOME, "jars") +elif len(JARS_PATH) > 1: + print("Assembly jars exist for multiple scalas ({0}), please cleanup assembly/target".format( + JARS_PATH), file=sys.stderr) + sys.exit(-1) +elif len(JARS_PATH) == 0 and not os.path.exists(TEMP_PATH): + print(incorrect_invocation_message, file=sys.stderr) + sys.exit(-1) + +EXAMPLES_PATH = os.path.join(SPARK_HOME, "examples/src/main/python") +SCRIPTS_PATH = os.path.join(SPARK_HOME, "bin") +SCRIPTS_TARGET = os.path.join(TEMP_PATH, "bin") +JARS_TARGET = os.path.join(TEMP_PATH, "jars") +EXAMPLES_TARGET = os.path.join(TEMP_PATH, "examples") + + +# Check and see if we are under the spark path in which case we need to build the symlink farm. +# This is important because we only want to build the symlink farm while under Spark otherwise we +# want to use the symlink farm. And if the symlink farm exists under while under Spark (e.g. a +# partially built sdist) we should error and have the user sort it out. +in_spark = (os.path.isfile("../core/src/main/scala/org/apache/spark/SparkContext.scala") or + (os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1)) + + +def _supports_symlinks(): + """Check if the system supports symlinks (e.g. *nix) or not.""" + return getattr(os, "symlink", None) is not None + + +if (in_spark): + # Construct links for setup + try: + os.mkdir(TEMP_PATH) + except: + print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), + file=sys.stderr) + exit(-1) + +try: + # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts + # find it where expected. The rest of the files aren't copied because they are accessed + # using Python imports instead which will be resolved correctly. + try: + os.makedirs("pyspark/python/pyspark") + except OSError: + # Don't worry if the directory already exists. + pass + copyfile("pyspark/shell.py", "pyspark/python/pyspark/shell.py") + + if (in_spark): + # Construct the symlink farm - this is necessary since we can't refer to the path above the + # package root and we need to copy the jars and scripts which are up above the python root. + if _supports_symlinks(): + os.symlink(JARS_PATH, JARS_TARGET) + os.symlink(SCRIPTS_PATH, SCRIPTS_TARGET) + os.symlink(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # For windows fall back to the slower copytree + copytree(JARS_PATH, JARS_TARGET) + copytree(SCRIPTS_PATH, SCRIPTS_TARGET) + copytree(EXAMPLES_PATH, EXAMPLES_TARGET) + else: + # If we are not inside of SPARK_HOME verify we have the required symlink farm + if not os.path.exists(JARS_TARGET): + print("To build packaging must be in the python directory under the SPARK_HOME.", + file=sys.stderr) + + if not os.path.isdir(SCRIPTS_TARGET): + print(incorrect_invocation_message, file=sys.stderr) + exit(-1) + + # Scripts directive requires a list of each script path and does not take wild cards. + script_names = os.listdir(SCRIPTS_TARGET) + scripts = list(map(lambda script: os.path.join(SCRIPTS_TARGET, script), script_names)) + # We add find_spark_home.py to the bin directory we install so that pip installed PySpark + # will search for SPARK_HOME with Python. 
+ scripts.append("pyspark/find_spark_home.py") + + # Parse the README markdown file into rst for PyPI + long_description = "!!!!! missing pandoc do not upload to PyPI !!!!" + try: + import pypandoc + long_description = pypandoc.convert('README.md', 'rst') + except ImportError: + print("Could not import pypandoc - required to package PySpark", file=sys.stderr) + + setup( + name='pyspark', + version=VERSION, + description='Apache Spark Python API', + long_description=long_description, + author='Spark Developers', + author_email='dev@spark.apache.org', + url='https://github.com/apache/spark/tree/master/python', + packages=['pyspark', + 'pyspark.mllib', + 'pyspark.ml', + 'pyspark.sql', + 'pyspark.streaming', + 'pyspark.bin', + 'pyspark.jars', + 'pyspark.python.pyspark', + 'pyspark.python.lib', + 'pyspark.examples.src.main.python'], + include_package_data=True, + package_dir={ + 'pyspark.jars': 'deps/jars', + 'pyspark.bin': 'deps/bin', + 'pyspark.python.lib': 'lib', + 'pyspark.examples.src.main.python': 'deps/examples', + }, + package_data={ + 'pyspark.jars': ['*.jar'], + 'pyspark.bin': ['*'], + 'pyspark.python.lib': ['*.zip'], + 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, + scripts=scripts, + license='http://www.apache.org/licenses/LICENSE-2.0', + install_requires=['py4j==0.10.4'], + setup_requires=['pypandoc'], + extras_require={ + 'ml': ['numpy>=1.7'], + 'mllib': ['numpy>=1.7'], + 'sql': ['pandas'] + }, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy'] + ) +finally: + # We only cleanup the symlink farm if we were in Spark, otherwise we are installing rather than + # packaging. + if (in_spark): + # Depending on cleaning up the symlink farm or copied version + if _supports_symlinks(): + os.remove(os.path.join(TEMP_PATH, "jars")) + os.remove(os.path.join(TEMP_PATH, "bin")) + os.remove(os.path.join(TEMP_PATH, "examples")) + else: + rmtree(os.path.join(TEMP_PATH, "jars")) + rmtree(os.path.join(TEMP_PATH, "bin")) + rmtree(os.path.join(TEMP_PATH, "examples")) + os.rmdir(TEMP_PATH) From 2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 16 Nov 2016 14:32:36 -0800 Subject: [PATCH 238/381] [SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support ## What changes were proposed in this pull request? While being evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction`s to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states. 
The following snippet shows the effect of this PR: ```scala import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") spark.range(100).createOrReplaceTempView("t") // A query using both Spark SQL native `max` and Hive `max` sql(s"SELECT max(id), hive_max(id) FROM t").explain() ``` Before this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax7475f57e), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- *Range (0, 100, step=1, splits=Some(1)) ``` After this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- *Range (0, 100, step=1, splits=Some(1)) ``` The tricky part of the PR is mostly about updating and passing around aggregation states of `HiveUDAFFunction`s since the aggregation state of a Hive UDAF may appear in three different forms. Let's take a look at the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear as the following forms at different time: 1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer` The form used by Hive UDAF API. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values. - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The existing aggregation state to be updated must be in this form. Conversions: - To form 2: `GenericUDAFEvaluator.terminatePartial()` - To form 3: Convert to form 2 first, and then to 3. 2. An `Object[]` array containing two `java.lang.Long` values. The form used to interact with Hive's `ObjectInspector`s. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state in form 1 to form 2. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The input aggregation state must be in this form. Conversions: - To form 1: No direct method. Have to create an empty `AggregationBuffer` and merge it into the empty buffer. - To form 3: `unwrapperFor()`/`unwrap()` method of `HiveInspectors` 3. The byte array that holds data of an `UnsafeRow` with two `LongType` fields. The form used by Spark SQL to shuffle partial aggregation results. This form is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array. Conversions: - To form 1: Convert to form 2 first, and then to 1. 
- To form 2: `wrapperFor()`/`wrap()` method of `HiveInspectors` Here're some micro-benchmark results produced by the most recent master and this PR branch. Master: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 339 / 372 3.1 323.2 1.0X w/ groupBy 503 / 529 2.1 479.7 0.7X ``` This PR: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 116 / 126 9.0 110.8 1.0X w/ groupBy 151 / 159 6.9 144.0 0.8X ``` Benchmark code snippet: ```scala test("Hive UDAF benchmark") { val N = 1 << 20 sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") val benchmark = new Benchmark( name = "hive udaf vs spark af", valuesPerIteration = N, minNumIters = 5, warmupTime = 5.seconds, minTime = 5.seconds, outputPerIteration = true ) benchmark.addCase("w/o groupBy") { _ => sparkSession.range(N).agg("id" -> "hive_max").collect() } benchmark.addCase("w/ groupBy") { _ => sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect() } benchmark.run() sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") } ``` ## How was this patch tested? New test suite `HiveUDAFSuite` is added. Author: Cheng Lian Closes #15703 from liancheng/partial-agg-hive-udaf. --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 199 +++++++++++++----- .../sql/hive/execution/HiveUDAFSuite.scala | 152 +++++++++++++ 2 files changed, 301 insertions(+), 50 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 42033080dc34..32edd4aec286 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.hive +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, - ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging @@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -75,7 +77,7 @@ private[hive] case 
class HiveSimpleUDF( @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) + method.getGenericReturnType, ObjectInspectorOptions.JAVA)) @transient private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF( } /** - * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt - * performance a lot. + * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following + * three formats: + * + * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class + * + * This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator` + * methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format. + * We call these methods to evaluate Hive UDAFs. + * + * 2. A Java object that can be inspected using the `ObjectInspector` returned by the + * `GenericUDAFEvaluator.init()` method. + * + * Hive uses this format to produce a serializable aggregation state so that it can shuffle + * partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance + * into a Spark SQL value, we have to convert it to this format first and then do the conversion + * with the help of `ObjectInspector`s. + * + * 3. A Spark SQL value + * + * We use this format for serializing Hive UDAF aggregation states on Spark side. To be more + * specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into + * `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization + * results. + * + * We may use the following methods to convert the aggregation state back and forth: + * + * - `wrap()`/`wrapperFor()`: from 3 to 1 + * - `unwrap()`/`unwrapperFor()`: from 1 to 3 + * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3 */ private[hive] case class HiveUDAFFunction( name: String, @@ -273,7 +302,7 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -281,73 +310,73 @@ private[hive] case class HiveUDAFFunction( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) + // Hive `ObjectInspector`s for all child expressions (input parameters of the function). @transient - private lazy val resolver = - if (isUDAFBridgeRequired) { + private lazy val inputInspectors = children.map(toInspector).toArray + + // Spark SQL data types of input parameters. 
+ @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + private def newEvaluator(): GenericUDAFEvaluator = { + val resolver = if (isUDAFBridgeRequired) { new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - @transient - private lazy val inspectors = children.map(toInspector).toArray - - @transient - private lazy val functionAndInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - val f = resolver.getEvaluator(parameterInfo) - f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) } + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. @transient - private lazy val function = functionAndInspector._1 + private lazy val partial1ModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect partial aggregation results. @transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + private val partialResultInspector = partial1ModeEvaluator.init( + GenericUDAFEvaluator.Mode.PARTIAL1, + inputInspectors + ) + // The UDAF evaluator used to merge partial aggregation results. @transient - private lazy val returnInspector = functionAndInspector._2 + private lazy val partial2ModeEvaluator = { + val evaluator = newEvaluator() + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) + evaluator + } + // Spark SQL data type of partial aggregation results @transient - private lazy val unwrapper = unwrapperFor(returnInspector) + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + // The UDAF evaluator used to compute the final result from a partial aggregation result objects. @transient - private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _ - - override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer)) + private lazy val finalModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient - private lazy val inputProjection = new InterpretedProjection(children) + private val returnInspector = finalModeEvaluator.init( + GenericUDAFEvaluator.Mode.FINAL, + Array(partialResultInspector) + ) + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format. @transient - private lazy val cached = new Array[AnyRef](children.length) + private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into + // Spark SQL specific format. @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation - // buffer for it. 
- override def aggBufferSchema: StructType = StructType(Nil) - - override def update(_buffer: InternalRow, input: InternalRow): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) - } - - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "Hive UDAF doesn't support partial aggregate") - } + private lazy val resultUnwrapper = unwrapperFor(returnInspector) - override def initialize(_buffer: InternalRow): Unit = { - buffer = function.getNewAggregationBuffer - } - - override val aggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. @@ -355,7 +384,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def supportsPartial: Boolean = true override lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction( val distinct = if (isDistinct) "DISTINCT " else " " s"$name($distinct${children.map(_.sql).mkString(", ")})" } + + override def createAggregationBuffer(): AggregationBuffer = + partial1ModeEvaluator.getNewAggregationBuffer + + @transient + private lazy val inputProjection = UnsafeProjection.create(children) + + override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + partial1ModeEvaluator.iterate( + buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) + } + + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation + // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts + // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and + // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. + partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + } + + override def eval(buffer: AggregationBuffer): Any = { + resultUnwrapper(finalModeEvaluator.terminate(buffer)) + } + + override def serialize(buffer: AggregationBuffer): Array[Byte] = { + // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can + // shuffle it for global aggregation later. + aggBufferSerDe.serialize(buffer) + } + + override def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare + // for global aggregation by merging multiple partial aggregation results within a single group. 
+ aggBufferSerDe.deserialize(bytes) + } + + // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects + private class AggregationBufferSerDe { + private val partialResultUnwrapper = unwrapperFor(partialResultInspector) + + private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) + + private val projection = UnsafeProjection.create(Array(partialResultDataType)) + + private val mutableRow = new GenericInternalRow(1) + + def serialize(buffer: AggregationBuffer): Array[Byte] = { + // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object + // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. + // Then we can unwrap it to a Spark SQL value. + mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer))) + val unsafeRow = projection(mutableRow) + val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) + unsafeRow.writeTo(bytes) + bytes.array() + } + + def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // `GenericUDAFEvaluator` doesn't provide any method that is capable to convert an object + // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The + // workaround here is creating an initial `AggregationBuffer` first and then merge the + // deserialized object into the buffer. + val buffer = partial2ModeEvaluator.getNewAggregationBuffer + val unsafeRow = new UnsafeRow(1) + unsafeRow.pointTo(bytes, bytes.length) + val partialResult = unsafeRow.get(0, partialResultDataType) + partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult)) + buffer + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala new file mode 100644 index 000000000000..c9ef72ee112c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} +import org.apache.hadoop.hive.ql.util.JavaDataModel +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + + Seq( + (0: Integer) -> "val_0", + (1: Integer) -> "val_1", + (2: Integer) -> null, + (3: Integer) -> null + ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") + } + + protected override def afterAll(): Unit = { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } + + test("built-in Hive UDAF") { + val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, 2), + Row(1, 3) + )) + } + + test("customized Hive UDAF") { + val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, Row(1, 1)), + Row(1, Row(1, 1)) + )) + } +} + +/** + * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column. 
+ */ +class MockUDAF extends AbstractGenericUDAFResolver { + override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator +} + +class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long) + extends GenericUDAFEvaluator.AbstractAggregationBuffer { + + override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2 +} + +class MockUDAFEvaluator extends GenericUDAFEvaluator { + private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val bufferOI = { + val fieldNames = Seq("nonNullCount", "nullCount").asJava + val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs) + } + + private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount") + + private val nullCountField = bufferOI.getStructFieldRef("nullCount") + + override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L) + + override def reset(agg: AggregationBuffer): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount = 0L + buffer.nullCount = 0L + } + + override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI + + override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + if (parameters.head eq null) { + buffer.nullCount += 1L + } else { + buffer.nonNullCount += 1L + } + } + + override def merge(agg: AggregationBuffer, partial: Object): Unit = { + if (partial ne null) { + val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField)) + val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField)) + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount += nonNullCount + buffer.nullCount += nullCount + } + } + + override def terminatePartial(agg: AggregationBuffer): AnyRef = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long) + } + + override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg) +} From 55589987be89ff78dadf44498352fbbd811a206e Mon Sep 17 00:00:00 2001 From: Artur Sukhenko Date: Wed, 16 Nov 2016 15:08:01 -0800 Subject: [PATCH 239/381] [YARN][DOC] Increasing NodeManager's heap size with External Shuffle Service ## What changes were proposed in this pull request? Suggest users to increase `NodeManager's` heap size if `External Shuffle Service` is enabled as `NM` can spend a lot of time doing GC resulting in shuffle operations being a bottleneck due to `Shuffle Read blocked time` bumped up. Also because of GC `NodeManager` can use an enormous amount of CPU and cluster performance will suffer. I have seen NodeManager using 5-13G RAM and up to 2700% CPU with `spark_shuffle` service on. ## How was this patch tested? #### Added step 5: ![shuffle_service](https://cloud.githubusercontent.com/assets/15244468/20355499/2fec0fde-ac2a-11e6-8f8b-1c80daf71be1.png) Author: Artur Sukhenko Closes #15906 from Devian-ua/nmHeapSize. --- docs/running-on-yarn.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cd18808681ec..fe0221ce7c5b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -559,6 +559,8 @@ pre-packaged distribution. 1. 
In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to `org.apache.spark.network.yarn.YarnShuffleService`. +1. Increase `NodeManager's` heap size by setting `YARN_HEAPSIZE` (1000 by default) in `etc/hadoop/yarn-env.sh` +to avoid garbage collection issues during shuffle. 1. Restart all `NodeManager`s in your cluster. The following extra configuration options are available when the shuffle service is running on YARN: From 170eeb345f951de89a39fe565697b3e913011768 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 17 Nov 2016 11:21:08 +0800 Subject: [PATCH 240/381] [SPARK-18442][SQL] Fix nullability of WrapOption. ## What changes were proposed in this pull request? The nullability of `WrapOption` should be `false`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15887 from ueshin/issues/SPARK-18442. --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e2ac3c36d9..0e3d99127ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -341,7 +341,7 @@ case class WrapOption(child: Expression, optType: DataType) override def dataType: DataType = ObjectType(classOf[Option[_]]) - override def nullable: Boolean = true + override def nullable: Boolean = false override def inputTypes: Seq[AbstractDataType] = optType :: Nil From 07b3f045cd6f79b92bc86b3b1b51d3d5e6bd37ce Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 00:00:38 -0800 Subject: [PATCH 241/381] [SPARK-18464][SQL] support old table which doesn't store schema in metastore ## What changes were proposed in this pull request? Before Spark 2.1, users can create an external data source table without schema, and we will infer the table schema at runtime. In Spark 2.1, we decided to infer the schema when the table was created, so that we don't need to infer it again and again at runtime. This is a good improvement, but we should still respect and support old tables which doesn't store table schema in metastore. ## How was this patch tested? regression test. Author: Wenchen Fan Closes #15900 from cloud-fan/hive-catalog. 
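Below is a hedged PySpark sketch of the user-visible behavior this patch keeps working. It assumes a legacy table named `old` that was created by Spark prior to 2.1 and therefore has no schema recorded in its table properties; the table name and setup are illustrative only:

```python
# Illustrative only: `old` is assumed to be a pre-2.1 data source table whose
# schema was never written to the metastore. With this change its schema is
# inferred from the underlying files at query time, as it was before Spark 2.1.
from pyspark.sql import SparkSession

spark = SparkSession.builder.enableHiveSupport().getOrCreate()

spark.sql("DESCRIBE old").show()   # columns are filled in by runtime inference
spark.table("old").show()          # the table remains fully queryable
```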
--- .../spark/sql/execution/command/tables.scala | 8 ++++++- .../spark/sql/hive/HiveExternalCatalog.scala | 5 +++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +++- .../sql/hive/MetastoreDataSourcesSuite.scala | 22 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 119e732d0202..7049e53a7868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -431,7 +431,13 @@ case class DescribeTableCommand( describeSchema(catalog.lookupRelation(table).schema, result) } else { val metadata = catalog.getTableMetadata(table) - describeSchema(metadata.schema, result) + if (metadata.schema.isEmpty) { + // In older version(prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it. + describeSchema(catalog.lookupRelation(metadata.identifier).schema, result) + } else { + describeSchema(metadata.schema, result) + } describePartitionInfo(metadata, result) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cbd00da81cfc..843305883abc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1023,6 +1023,11 @@ object HiveExternalCatalog { // After SPARK-6024, we removed this flag. // Although we are not using `spark.sql.sources.schema` any more, we need to still support. DataType.fromJson(schema.get).asInstanceOf[StructType] + } else if (props.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)).isEmpty) { + // If there is no schema information in table properties, it means the schema of this table + // was empty when saving into metastore, which is possible in older version(prior to 2.1) of + // Spark. We should respect it. + new StructType() } else { val numSchemaParts = props.get(DATASOURCE_SCHEMA_NUMPARTS) if (numSchemaParts.isDefined) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8e5fc88aad44..edbde5d10b47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -64,7 +64,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val dataSource = DataSource( sparkSession, - userSpecifiedSchema = Some(table.schema), + // In older version(prior to 2.1) of Spark, the table schema can be empty and should be + // inferred at runtime. We should still support it. 
+ userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c50f92e783c8..4ab1a54edc46 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1371,4 +1371,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } } + + test("SPARK-18464: support old table which doesn't store schema in table properties") { + withTable("old") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) + val tableDesc = CatalogTable( + identifier = TableIdentifier("old", Some("default")), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy( + properties = Map("path" -> path.getAbsolutePath) + ), + schema = new StructType(), + properties = Map( + HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet")) + hiveClient.createTable(tableDesc, ignoreIfExists = false) + + checkAnswer(spark.table("old"), Row(1, "a")) + + checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + } + } + } } From a3cac7bd86a6fe8e9b42da1bf580aaeb59378304 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 17 Nov 2016 11:13:22 +0000 Subject: [PATCH 242/381] [YARN][DOC] Remove non-Yarn specific configurations from running-on-yarn.md ## What changes were proposed in this pull request? Remove `spark.driver.memory`, `spark.executor.memory`, `spark.driver.cores`, and `spark.executor.cores` from `running-on-yarn.md` as they are not Yarn-specific, and they are also defined in`configuration.md`. ## How was this patch tested? Build passed & Manually check. Author: Weiqing Yang Closes #15869 from weiqingy/yarnDoc. --- docs/running-on-yarn.md | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index fe0221ce7c5b..4d1fafc07b8f 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -117,28 +117,6 @@ To use a custom metrics.properties for the application master and executors, upd Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. - - spark.driver.memory - 1g - - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
Note: In client mode, this config must not be set through the SparkConf - directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-memory command line option - or in your default properties file. - - - - spark.driver.cores - 1 - - Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. - - spark.yarn.am.cores 1 @@ -233,13 +211,6 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. - - spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - - The number of cores to use on each executor. For YARN and standalone mode only. - - spark.executor.instances 2 @@ -247,13 +218,6 @@ To use a custom metrics.properties for the application master and executors, upd The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large. - - spark.executor.memory - 1g - - Amount of memory to use per executor process (e.g. 2g, 8g). - - spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 From 49b6f456aca350e9e2c170782aa5cc75e7822680 Mon Sep 17 00:00:00 2001 From: anabranch Date: Thu, 17 Nov 2016 11:34:55 +0000 Subject: [PATCH 243/381] [SPARK-18365][DOCS] Improve Sample Method Documentation ## What changes were proposed in this pull request? I found the documentation for the sample method to be confusing, this adds more clarification across all languages. - [x] Scala - [x] Python - [x] R - [x] RDD Scala - [ ] RDD Python with SEED - [X] RDD Java - [x] RDD Java with SEED - [x] RDD Python ## How was this patch tested? NA Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: anabranch Author: Bill Chambers Closes #15815 from anabranch/SPARK-18365. --- R/pkg/R/DataFrame.R | 4 +++- .../main/scala/org/apache/spark/api/java/JavaRDD.scala | 8 ++++++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 +++ python/pyspark/rdd.py | 5 +++++ python/pyspark/sql/dataframe.py | 5 +++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++-- 6 files changed, 30 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1cf9b38ea648..4e3d97bb3ad0 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -936,7 +936,9 @@ setMethod("unique", #' Sample #' -#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of of the given SparkDataFrame. 
#' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7..d67cff64e6e4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -98,7 +98,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size @@ -109,7 +111,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e018af35cb18..cded899db1f5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -466,6 +466,9 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2de2c2fd1a60..a163ceafe9d3 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -386,6 +386,11 @@ def sample(self, withReplacement, fraction, seed=None): with replacement: expected number of times each element is chosen; fraction must be >= 0 :param seed: seed for the random number generator + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`DataFrame`. + >>> rdd = sc.parallelize(range(100), 4) >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 True diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 29710acf54c4..38998900837c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -549,6 +549,11 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + .. note:: + + This is not guaranteed to provide exactly the fraction specified of the total count + of the given :class:`DataFrame`. 
+ >>> df.sample(False, 0.5, 42).count() 2 """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index af30683cc01c..3761773698df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1646,7 +1646,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. @@ -1665,7 +1668,10 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by sampling a fraction of rows, using a random seed. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. + * + * Note: this is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. From de77c67750dc868d75d6af173c3820b75a9fe4b7 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Thu, 17 Nov 2016 13:37:42 +0000 Subject: [PATCH 244/381] [SPARK-17462][MLLIB]use VersionUtils to parse Spark version strings ## What changes were proposed in this pull request? Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use the VersionUtils. This PR replaces custom regexes with VersionUtils to get Spark version numbers. ## How was this patch tested? Existing tests. Signed-off-by: VinceShieh vincent.xieintel.com Author: VinceShieh Closes #15055 from VinceShieh/SPARK-17462. 
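The shared utility is Scala (`org.apache.spark.util.VersionUtils`); purely as an illustration of the parsing it centralizes, and not as Spark code, the major-version extraction amounts to something like:

```python
# Rough sketch (not Spark code) of "parse the major version out of a Spark
# version string", e.g. to decide whether persisted ML model data was written
# by Spark 1.x or 2.x.
import re


def major_version(spark_version):
    match = re.match(r"^(\d+)\.(\d+)(\..*)?$", spark_version)
    if match is None:
        raise ValueError("could not parse Spark version string: %r" % spark_version)
    return int(match.group(1))


assert major_version("1.6.3") == 1
assert major_version("2.1.0-SNAPSHOT") == 2
```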
--- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++---- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a0d481b294ac..26505b4cc150 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Common params for KMeans and KMeansModel @@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - - val clusterCenters = if (major.toInt >= 2) { + val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 444006fe1edb..1e49352b8517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Params for [[PCA]] and [[PCAModel]]. @@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - val dataPath = new Path(path, "data").toString - val model = if (major.toInt >= 2) { + val model = if (majorVersion(metadata.sparkVersion) >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") From cdaf4ce9fe58c4606be8aa2a5c3756d30545c850 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 17 Nov 2016 13:40:16 +0000 Subject: [PATCH 245/381] [SPARK-18480][DOCS] Fix wrong links for ML guide docs ## What changes were proposed in this pull request? 1, There are two `[Graph.partitionBy]` in `graphx-programming-guide.md`, the first one had no effert. 2, `DataFrame`, `Transformer`, `Pipeline` and `Parameter` in `ml-pipeline.md` were linked to `ml-guide.html` by mistake. 3, `PythonMLLibAPI` in `mllib-linear-methods.md` was not accessable, because class `PythonMLLibAPI` is private. 4, Other link updates. ## How was this patch tested? manual tests Author: Zheng RuiFeng Closes #15912 from zhengruifeng/md_fix. 
--- docs/graphx-programming-guide.md | 1 - docs/ml-classification-regression.md | 4 ++-- docs/ml-features.md | 2 +- docs/ml-pipeline.md | 12 ++++++------ docs/mllib-linear-methods.md | 4 +--- .../main/scala/org/apache/spark/ml/feature/LSH.scala | 2 +- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 8 ++++---- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 ++++---- 8 files changed, 19 insertions(+), 22 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 1097cf1211c1..e271b28fb4f2 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -36,7 +36,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] [Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] [PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] [PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ [ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ [TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 1aacc3e054b5..43cc79b9c081 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -984,7 +984,7 @@ Random forests combine many decision trees in order to reduce the risk of overfi The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html#random-forests). ### Inputs and Outputs @@ -1065,7 +1065,7 @@ GBTs iteratively train decision trees in order to minimize a loss function. The `spark.ml` implementation supports GBTs for binary classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html#gradient-boosted-trees-gbts). ### Inputs and Outputs diff --git a/docs/ml-features.md b/docs/ml-features.md index 19ec5746978a..d2f036fb083d 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -710,7 +710,7 @@ for more details on the API. `VectorIndexer` helps index categorical features in datasets of `Vector`s. It can both automatically decide which features are categorical and convert original values to category indices. Specifically, it does the following: -1. Take an input column of type [Vector](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and a parameter `maxCategories`. +1. 
Take an input column of type [Vector](api/scala/index.html#org.apache.spark.ml.linalg.Vector) and a parameter `maxCategories`. 2. Decide which features should be categorical based on the number of distinct values, where features with at most `maxCategories` are declared categorical. 3. Compute 0-based category indices for each categorical feature. 4. Index categorical features and transform original feature values to indices. diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index b4d6be94f5eb..0384513ab701 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -38,26 +38,26 @@ algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[`DataFrame`](ml-guide.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML +* **[`DataFrame`](ml-pipeline.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML dataset, which can hold a variety of data types. E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +* **[`Transformer`](ml-pipeline.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms a `DataFrame` with features into a `DataFrame` with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +* **[`Estimator`](ml-pipeline.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. -* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. +* **[`Pipeline`](ml-pipeline.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-pipeline.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. ## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. This API adopts the `DataFrame` from Spark SQL in order to support a variety of data types. -`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#data-types) for a list of supported types. In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. 
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 816bdf131700..3085539b40e6 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -139,7 +139,7 @@ and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. -The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html#labeled-point) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. ### Linear Support Vector Machines (SVMs) @@ -491,5 +491,3 @@ Algorithms are all implemented in Scala: * [RidgeRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) * [LassoWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) -Python calls the Scala implementation via -[PythonMLLibAPI](api/scala/index.html#org.apache.spark.mllib.api.python.PythonMLLibAPI). diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 333a8c364a88..eb117c40eea3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -40,7 +40,7 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol { * @group param */ final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + - "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + " improves the running performance", ParamValidators.gt(0)) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d..ede0a060eef9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,7 +34,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) @@ -59,7 +59,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. @@ -162,7 +162,7 @@ private[spark] object GradientBoostedTrees extends Logging { * Method to calculate error of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
+ * @param data Training dataset: RDD of [[LabeledPoint]]. * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. @@ -184,7 +184,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to compute error or loss for every iteration of gradient boosting. * - * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param data RDD of [[LabeledPoint]] * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b504f411d256..8ae5ca3c84b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -82,7 +82,7 @@ private[spark] object RandomForest extends Logging { /** * Train a random forest. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @return an unweighted set of trees */ def run( @@ -343,7 +343,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] + * @param input Training data: RDD of [[TreePoint]] * @param metadata Learning and dataset metadata * @param topNodesForGroup For each tree in group, tree index -> root node. * Used for matching instances with nodes. @@ -854,10 +854,10 @@ private[spark] object RandomForest extends Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @param metadata Learning and dataset metadata * @param seed random seed - * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]] + * @return Splits, an Array of [[Split]] * of size (numFeatures, numSplits) */ protected[tree] def findSplits( From b0aa1aa1af6c513a6a881eaea96abdd2b480ef98 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 17 Nov 2016 17:04:19 +0000 Subject: [PATCH 246/381] [SPARK-18490][SQL] duplication nodename extrainfo for ShuffleExchange ## What changes were proposed in this pull request? In ShuffleExchange, the nodename's extraInfo are the same when exchangeCoordinator.isEstimated is true or false. Merge the two situation in the PR. Author: root Closes #15920 from windpiger/DupNodeNameShuffleExchange. 
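
Concretely, the two guarded cases collapse into one, since the guard never changes the generated string (a sketch mirroring the diff below):

    val extraInfo = coordinator match {
      // Before: separate cases for isEstimated == true and false that built identical strings
      // After: a single case, because the estimation state does not affect the node name
      case Some(exchangeCoordinator) =>
        s"(coordinator id: ${System.identityHashCode(coordinator)})"
      case None => ""
    }
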
--- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 7a4a25137070..125a4930c652 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -45,9 +45,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { - case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - s"(coordinator id: ${System.identityHashCode(coordinator)})" - case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + case Some(exchangeCoordinator) => s"(coordinator id: ${System.identityHashCode(coordinator)})" case None => "" } From ce13c2672318242748f7520ed4ce6bcfad4fb428 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 Nov 2016 17:31:12 -0800 Subject: [PATCH 247/381] [SPARK-18360][SQL] default table path of tables in default database should depend on the location of default database ## What changes were proposed in this pull request? The current semantic of the warehouse config: 1. it's a static config, which means you can't change it once your spark application is launched. 2. Once a database is created, its location won't change even the warehouse path config is changed. 3. default database is a special case, although its location is fixed, but the locations of tables created in it are not. If a Spark app starts with warehouse path B(while the location of default database is A), then users create a table `tbl` in default database, its location will be `B/tbl` instead of `A/tbl`. If uses change the warehouse path config to C, and create another table `tbl2`, its location will still be `B/tbl2` instead of `C/tbl2`. rule 3 doesn't make sense and I think we made it by mistake, not intentionally. Data source tables don't follow rule 3 and treat default database like normal ones. This PR fixes hive serde tables to make it consistent with data source tables. ## How was this patch tested? HiveSparkSubmitSuite Author: Wenchen Fan Closes #15812 from cloud-fan/default-db. --- .../spark/sql/hive/HiveExternalCatalog.scala | 237 ++++++++++-------- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 76 +++++- 2 files changed, 190 insertions(+), 123 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 843305883abc..cacffcf33c26 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -197,136 +197,151 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) - } else if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { - // Here we follow data source tables and put table metadata like provider, schema, etc. in - // table properties, so that we can work around the Hive metastore issue about not case - // preserving and make Hive serde table support mixed-case column names. 
- val tableWithDataSourceProps = tableDefinition.copy( - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) } else { - // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type - // support, no column nullability, etc., we should do some extra works before saving table - // metadata into Hive metastore: - // 1. Put table metadata like provider, schema, etc. in table properties. - // 2. Check if this table is hive compatible. - // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket - // spec to empty and save table metadata to Hive. - // 2.2 If it's hive compatible, set serde information in table metadata and try to save - // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 - val tableProperties = tableMetaToTableProps(tableDefinition) - // Ideally we should not create a managed table with location, but Hive serde table can // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have // to create the table directory and write out data before we create this table, to avoid // exposing a partial written table. val needDefaultTableLocation = tableDefinition.tableType == MANAGED && tableDefinition.storage.locationUri.isEmpty + val tableLocation = if (needDefaultTableLocation) { Some(defaultTablePath(tableDefinition.identifier)) } else { tableDefinition.storage.locationUri } - // Ideally we should also put `locationUri` in table properties like provider, schema, etc. - // However, in older version of Spark we already store table location in storage properties - // with key "path". Here we keep this behaviour for backward compatibility. - val storagePropsWithLocation = tableDefinition.storage.properties ++ - tableLocation.map("path" -> _) - - // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and - // bucket specification to empty. Note that partition columns are retained, so that we can - // call partition-related Hive API later. - def newSparkSQLSpecificMetastoreTable(): CatalogTable = { - tableDefinition.copy( - // Hive only allows directory paths as location URIs while Spark SQL data source tables - // also allow file paths. For non-hive-compatible format, we should not set location URI - // to avoid hive metastore to throw exception. - storage = tableDefinition.storage.copy( - locationUri = None, - properties = storagePropsWithLocation), - schema = tableDefinition.partitionSchema, - bucketSpec = None, - properties = tableDefinition.properties ++ tableProperties) + + if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like provider, schema, etc. in + // table properties, so that we can work around the Hive metastore issue about not case + // preserving and make Hive serde table support mixed-case column names. 
+ properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) + } else { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) } + } + } - // converts the table metadata to Hive compatible format, i.e. set the serde information. - def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { - val location = if (tableDefinition.tableType == EXTERNAL) { - // When we hit this branch, we are saving an external data source table with hive - // compatible format, which means the data source is file-based and must have a `path`. - require(tableDefinition.storage.locationUri.isDefined, - "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(tableDefinition.location).toUri.toString) - } else { - None - } + private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = { + // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type + // support, no column nullability, etc., we should do some extra works before saving table + // metadata into Hive metastore: + // 1. Put table metadata like provider, schema, etc. in table properties. + // 2. Check if this table is hive compatible. + // 2.1 If it's not hive compatible, set location URI, schema, partition columns and bucket + // spec to empty and save table metadata to Hive. + // 2.2 If it's hive compatible, set serde information in table metadata and try to save + // it to Hive. If it fails, treat it as not hive compatible and go back to 2.1 + val tableProperties = tableMetaToTableProps(table) + + // Ideally we should also put `locationUri` in table properties like provider, schema, etc. + // However, in older version of Spark we already store table location in storage properties + // with key "path". Here we keep this behaviour for backward compatibility. + val storagePropsWithLocation = table.storage.properties ++ + table.storage.locationUri.map("path" -> _) + + // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and + // bucket specification to empty. Note that partition columns are retained, so that we can + // call partition-related Hive API later. + def newSparkSQLSpecificMetastoreTable(): CatalogTable = { + table.copy( + // Hive only allows directory paths as location URIs while Spark SQL data source tables + // also allow file paths. For non-hive-compatible format, we should not set location URI + // to avoid hive metastore to throw exception. + storage = table.storage.copy( + locationUri = None, + properties = storagePropsWithLocation), + schema = table.partitionSchema, + bucketSpec = None, + properties = table.properties ++ tableProperties) + } - tableDefinition.copy( - storage = tableDefinition.storage.copy( - locationUri = location, - inputFormat = serde.inputFormat, - outputFormat = serde.outputFormat, - serde = serde.serde, - properties = storagePropsWithLocation - ), - properties = tableDefinition.properties ++ tableProperties) + // converts the table metadata to Hive compatible format, i.e. set the serde information. + def newHiveCompatibleMetastoreTable(serde: HiveSerDe): CatalogTable = { + val location = if (table.tableType == EXTERNAL) { + // When we hit this branch, we are saving an external data source table with hive + // compatible format, which means the data source is file-based and must have a `path`. 
+ require(table.storage.locationUri.isDefined, + "External file-based data source table must have a `path` entry in storage properties.") + Some(new Path(table.location).toUri.toString) + } else { + None } - val qualifiedTableName = tableDefinition.identifier.quotedString - val maybeSerde = HiveSerDe.sourceToSerDe(tableDefinition.provider.get) - val skipHiveMetadata = tableDefinition.storage.properties - .getOrElse("skipHiveMetadata", "false").toBoolean - - val (hiveCompatibleTable, logMessage) = maybeSerde match { - case _ if skipHiveMetadata => - val message = - s"Persisting data source table $qualifiedTableName into Hive metastore in" + - "Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - - // our bucketing is un-compatible with hive(different hash function) - case _ if tableDefinition.bucketSpec.nonEmpty => - val message = - s"Persisting bucketed data source table $qualifiedTableName into " + - "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " - (None, message) - - case Some(serde) => - val message = - s"Persisting file based data source table $qualifiedTableName into " + - s"Hive metastore in Hive compatible format." - (Some(newHiveCompatibleMetastoreTable(serde)), message) - - case _ => - val provider = tableDefinition.provider.get - val message = - s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - s"Persisting data source table $qualifiedTableName into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive." - (None, message) - } + table.copy( + storage = table.storage.copy( + locationUri = location, + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + properties = storagePropsWithLocation + ), + properties = table.properties ++ tableProperties) + } - (hiveCompatibleTable, logMessage) match { - case (Some(table), message) => - // We first try to save the metadata of the table in a Hive compatible way. - // If Hive throws an error, we fall back to save its metadata in the Spark SQL - // specific way. - try { - logInfo(message) - saveTableIntoHive(table, ignoreIfExists) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not persist ${tableDefinition.identifier.quotedString} in a Hive " + - "compatible way. Persisting it into Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + val qualifiedTableName = table.identifier.quotedString + val maybeSerde = HiveSerDe.sourceToSerDe(table.provider.get) + val skipHiveMetadata = table.storage.properties + .getOrElse("skipHiveMetadata", "false").toBoolean + + val (hiveCompatibleTable, logMessage) = maybeSerde match { + case _ if skipHiveMetadata => + val message = + s"Persisting data source table $qualifiedTableName into Hive metastore in" + + "Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + + // our bucketing is un-compatible with hive(different hash function) + case _ if table.bucketSpec.nonEmpty => + val message = + s"Persisting bucketed data source table $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + (None, message) + + case Some(serde) => + val message = + s"Persisting file based data source table $qualifiedTableName into " + + s"Hive metastore in Hive compatible format." 
+ (Some(newHiveCompatibleMetastoreTable(serde)), message) + + case _ => + val provider = table.provider.get + val message = + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + s"Persisting data source table $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + } - case (None, message) => - logWarning(message) - saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) - } + (hiveCompatibleTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatible way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + saveTableIntoHive(table, ignoreIfExists) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not persist ${table.identifier.quotedString} in a Hive " + + "compatible way. Persisting it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) + } + + case (None, message) => + logWarning(message) + saveTableIntoHive(newSparkSQLSpecificMetastoreTable(), ignoreIfExists) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index fbd705172cae..a670560c5969 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -24,6 +24,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.tools.nsc.Properties +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException @@ -33,11 +34,12 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -295,6 +297,20 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-18360: default table path of tables in default database should depend on the " + + "location of default database") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_18360.getClass.getName.stripSuffix("$"), + "--name", "SPARK-18360", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
// This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -397,11 +413,7 @@ object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") - val sparkConf = new SparkConf(loadDefaults = true) - val builder = SparkSession.builder() - .config(sparkConf) - .config("spark.ui.enabled", "false") - .enableHiveSupport() + val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = sparkConf.getOption("spark.sql.test.expectedWarehouseDir") @@ -410,7 +422,7 @@ object SetWarehouseLocationTest extends Logging { // If spark.sql.test.expectedWarehouseDir is set, the warehouse dir is set // through spark-summit. So, neither spark.sql.warehouse.dir nor // hive.metastore.warehouse.dir is set at here. - (builder.getOrCreate(), warehouseDir) + (new TestHiveContext(new SparkContext(sparkConf)).sparkSession, warehouseDir) case None => val warehouseLocation = Utils.createTempDir() warehouseLocation.delete() @@ -420,10 +432,10 @@ object SetWarehouseLocationTest extends Logging { // spark.sql.warehouse.dir and hive.metastore.warehouse.dir. // We are expecting that the value of spark.sql.warehouse.dir will override the // value of hive.metastore.warehouse.dir. - val session = builder - .config("spark.sql.warehouse.dir", warehouseLocation.toString) - .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) - .getOrCreate() + val session = new TestHiveContext(new SparkContext(sparkConf + .set("spark.sql.warehouse.dir", warehouseLocation.toString) + .set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString))) + .sparkSession (session, warehouseLocation.toString) } @@ -801,3 +813,43 @@ object SPARK_14244 extends QueryTest { } } } + +object SPARK_18360 { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder() + .config("spark.ui.enabled", "false") + .enableHiveSupport().getOrCreate() + + val defaultDbLocation = spark.catalog.getDatabase("default").locationUri + assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) + + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + try { + val tableMeta = CatalogTable( + identifier = TableIdentifier("test_tbl", Some("default")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("i", "int"), + provider = Some(DDLUtils.HIVE_PROVIDER)) + + val newWarehousePath = Utils.createTempDir().getAbsolutePath + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$newWarehousePath") + hiveClient.createTable(tableMeta, ignoreIfExists = false) + val rawTable = hiveClient.getTable("default", "test_tbl") + // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table + // location for tables in default database. + assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) + + spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) + val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") + // Spark SQL will use the location of default database to generate default table + // location for tables in default database. 
+ assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + } finally { + hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) + hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") + } + } +} From d9dd979d170f44383a9a87f892f2486ddb3cca7d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Nov 2016 18:45:15 -0800 Subject: [PATCH 248/381] [SPARK-18462] Fix ClassCastException in SparkListenerDriverAccumUpdates event ## What changes were proposed in this pull request? This patch fixes a `ClassCastException: java.lang.Integer cannot be cast to java.lang.Long` error which could occur in the HistoryServer while trying to process a deserialized `SparkListenerDriverAccumUpdates` event. The problem stems from how `jackson-module-scala` handles primitive type parameters (see https://github.com/FasterXML/jackson-module-scala/wiki/FAQ#deserializing-optionint-and-other-primitive-challenges for more details). This was causing a problem where our code expected a field to be deserialized as a `(Long, Long)` tuple but we got an `(Int, Int)` tuple instead. This patch hacks around this issue by registering a custom `Converter` with Jackson in order to deserialize the tuples as `(Object, Object)` and perform the appropriate casting. ## How was this patch tested? New regression tests in `SQLListenerSuite`. Author: Josh Rosen Closes #15922 from JoshRosen/SPARK-18462. --- .../spark/sql/execution/ui/SQLListener.scala | 39 +++++++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 44 ++++++++++++++++++- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 60f13432d78d..5daf21595d8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable +import com.fasterxml.jackson.databind.JavaType +import com.fasterxml.jackson.databind.`type`.TypeFactory +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.util.Converter + import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -43,9 +48,41 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent @DeveloperApi -case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)]) +case class SparkListenerDriverAccumUpdates( + executionId: Long, + @JsonDeserialize(contentConverter = classOf[LongLongTupleConverter]) + accumUpdates: Seq[(Long, Long)]) extends SparkListenerEvent +/** + * Jackson [[Converter]] for converting an (Int, Int) tuple into a (Long, Long) tuple. + * + * This is necessary due to limitations in how Jackson's scala module deserializes primitives; + * see the "Deserializing Option[Int] and other primitive challenges" section in + * https://github.com/FasterXML/jackson-module-scala/wiki/FAQ for a discussion of this issue and + * SPARK-18462 for the specific problem that motivated this conversion. 
+ */ +private class LongLongTupleConverter extends Converter[(Object, Object), (Long, Long)] { + + override def convert(in: (Object, Object)): (Long, Long) = { + def toLong(a: Object): Long = a match { + case i: java.lang.Integer => i.intValue() + case l: java.lang.Long => l.longValue() + } + (toLong(in._1), toLong(in._2)) + } + + override def getInputType(typeFactory: TypeFactory): JavaType = { + val objectType = typeFactory.uncheckedSimpleType(classOf[Object]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType)) + } + + override def getOutputType(typeFactory: TypeFactory): JavaType = { + val longType = typeFactory.uncheckedSimpleType(classOf[Long]) + typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) + } +} + class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 948a155457b6..8aea112897fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import org.json4s.jackson.JsonMethods._ import org.mockito.Mockito.mock import org.apache.spark._ @@ -35,10 +36,10 @@ import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanIn import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator} +import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} -class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ import org.apache.spark.AccumulatorSuite.makeInfo @@ -416,6 +417,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) } + test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { + val event = SparkListenerDriverAccumUpdates(1L, Seq((2L, 3L))) + val json = JsonProtocol.sparkEventToJson(event) + assertValidDataInJson(json, + parse(""" + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 1, + | "accumUpdates": [[2,3]] + |} + """.stripMargin)) + JsonProtocol.sparkEventFromJson(json) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 1L) + accums.foreach { case (a, b) => + assert(a == 2L) + assert(b == 3L) + } + } + + // Test a case where the numbers in the JSON can only fit in longs: + val longJson = parse( + """ + |{ + | "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates", + | "executionId": 4294967294, + | "accumUpdates": [[4294967294,3]] + |} + """.stripMargin) + JsonProtocol.sparkEventFromJson(longJson) match { + case SparkListenerDriverAccumUpdates(executionId, accums) => + assert(executionId == 4294967294L) + accums.foreach { case (a, b) => + assert(a == 4294967294L) + assert(b == 3L) + } + } + } + } From 
6c6aba229c68bc2ba341d12e5d4bbbae8dc880bc Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 18 Nov 2016 09:17:49 -0800 Subject: [PATCH 249/381] refine connect and read code --- .../spark/sql/DatasetToArrowSuite.scala | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala index e954cdc751a6..0e988558a2b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.{DataInputStream, EOFException, RandomAccessFile} +import java.io._ import java.net.{InetAddress, Socket} import java.nio.channels.FileChannel @@ -25,8 +25,8 @@ import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.file.ArrowReader import org.apache.arrow.vector.schema.ArrowRecordBatch - import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils case class ArrowIntTest(a: Int, b: Int) @@ -67,6 +67,8 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { class RecordBatchReceiver { + val allocator = new RootAllocator(Long.MaxValue) + def array(buf: ArrowBuf): Array[Byte] = { val bytes = Array.ofDim[Byte](buf.readableBytes()) buf.readBytes(bytes) @@ -74,33 +76,32 @@ class RecordBatchReceiver { } def connectAndRead(port: Int): (Array[Byte], Int) = { - val s = new Socket(InetAddress.getByName("localhost"), port) - val is = s.getInputStream + val clientSocket = new Socket(InetAddress.getByName("localhost"), port) + val clientDataIns = new DataInputStream(clientSocket.getInputStream) - val dis = new DataInputStream(is) - val len = dis.readInt() + val messageLength = clientDataIns.readInt() - val buffer = Array.ofDim[Byte](len) - val bytesRead = dis.read(buffer) - if (bytesRead != len) { - throw new EOFException("Wrong EOF") + val buffer = Array.ofDim[Byte](messageLength) + val bytesRead = clientDataIns.read(buffer) + if (bytesRead != messageLength) { + throw new EOFException("Wrong EOF to read Arrow Bytes") } - (buffer, len) + (buffer, messageLength) } def makeFile(buffer: Array[Byte]): FileChannel = { - var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw") - aFile.write(buffer) - aFile.close() - - aFile = new RandomAccessFile("/tmp/nio-data.txt", "r") - val fChannel = aFile.getChannel - fChannel + val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath + val arrowFile = new File(tempDir, "arrow-bytes") + val arrowOus = new FileOutputStream(arrowFile.getPath) + arrowOus.write(buffer) + arrowOus.close() + + val arrowIns = new FileInputStream(arrowFile.getPath) + arrowIns.getChannel } - def readRecordBatch(fc: FileChannel, len: Int): ArrowRecordBatch = { - val allocator = new RootAllocator(len) - val reader = new ArrowReader(fc, allocator) + def readRecordBatch(channel: FileChannel, len: Int): ArrowRecordBatch = { + val reader = new ArrowReader(channel, allocator) val footer = reader.readFooter() val schema = footer.getSchema val blocks = footer.getRecordBatches From 51baca2219fda8692b88fc8552548544aec73a1e Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Fri, 18 Nov 2016 11:11:24 -0800 Subject: [PATCH 250/381] [SPARK-18187][SQL] CompactibleFileStreamLog should not use "compactInterval" direcly with user setting. ## What changes were proposed in this pull request? 
CompactibleFileStreamLog relys on "compactInterval" to detect a compaction batch. If the "compactInterval" is reset by user, CompactibleFileStreamLog will return wrong answer, resulting data loss. This PR procides a way to check the validity of 'compactInterval', and calculate an appropriate value. ## How was this patch tested? When restart a stream, we change the 'spark.sql.streaming.fileSource.log.compactInterval' different with the former one. The primary solution to this issue was given by uncleGen Added extensions include an additional metadata field in OffsetSeq and CompactibleFileStreamLog APIs. zsxwing Author: Tyson Condie Author: genmao.ygm Closes #15852 from tcondie/spark-18187. --- .../streaming/CompactibleFileStreamLog.scala | 61 ++++++++++++++++++- .../streaming/FileStreamSinkLog.scala | 8 ++- .../streaming/FileStreamSourceLog.scala | 9 +-- .../execution/streaming/HDFSMetadataLog.scala | 2 +- .../sql/execution/streaming/OffsetSeq.scala | 12 +++- .../execution/streaming/OffsetSeqLog.scala | 31 +++++++--- .../CompactibleFileStreamLogSuite.scala | 33 ++++++++++ .../sql/streaming/FileStreamSourceSuite.scala | 41 ++++++++----- .../spark/sql/streaming/StreamTest.scala | 20 +++++- 9 files changed, 178 insertions(+), 39 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 8af3db196888..8529ceac30f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -63,7 +63,46 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( protected def isDeletingExpiredLog: Boolean - protected def compactInterval: Int + protected def defaultCompactInterval: Int + + protected final lazy val compactInterval: Int = { + // SPARK-18187: "compactInterval" can be set by user via defaultCompactInterval. + // If there are existing log entries, then we should ensure a compatible compactInterval + // is used, irrespective of the defaultCompactInterval. There are three cases: + // + // 1. If there is no '.compact' file, we can use the default setting directly. + // 2. If there are two or more '.compact' files, we use the interval of patch id suffix with + // '.compact' as compactInterval. This case could arise if isDeletingExpiredLog == false. + // 3. If there is only one '.compact' file, then we must find a compact interval + // that is compatible with (i.e., a divisor of) the previous compact file, and that + // faithfully tries to represent the revised default compact interval i.e., is at least + // is large if possible. 
+ // e.g., if defaultCompactInterval is 5 (and previous compact interval could have + // been any 2,3,4,6,12), then a log could be: 11.compact, 12, 13, in which case + // will ensure that the new compactInterval = 6 > 5 and (11 + 1) % 6 == 0 + val compactibleBatchIds = fileManager.list(metadataPath, batchFilesFilter) + .filter(f => f.getPath.toString.endsWith(CompactibleFileStreamLog.COMPACT_FILE_SUFFIX)) + .map(f => pathToBatchId(f.getPath)) + .sorted + .reverse + + // Case 1 + var interval = defaultCompactInterval + if (compactibleBatchIds.length >= 2) { + // Case 2 + val latestCompactBatchId = compactibleBatchIds(0) + val previousCompactBatchId = compactibleBatchIds(1) + interval = (latestCompactBatchId - previousCompactBatchId).toInt + } else if (compactibleBatchIds.length == 1) { + // Case 3 + interval = CompactibleFileStreamLog.deriveCompactInterval( + defaultCompactInterval, compactibleBatchIds(0).toInt) + } + assert(interval > 0, s"intervalValue = $interval not positive value.") + logInfo(s"Set the compact interval to $interval " + + s"[defaultCompactInterval: $defaultCompactInterval]") + interval + } /** * Filter out the obsolete logs. @@ -245,4 +284,24 @@ object CompactibleFileStreamLog { def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 } + + /** + * Derives a compact interval from the latest compact batch id and + * a default compact interval. + */ + def deriveCompactInterval(defaultInterval: Int, latestCompactBatchId: Int) : Int = { + if (latestCompactBatchId + 1 <= defaultInterval) { + latestCompactBatchId + 1 + } else if (defaultInterval < (latestCompactBatchId + 1) / 2) { + // Find the first divisor >= default compact interval + def properDivisors(min: Int, n: Int) = + (min to n/2).view.filter(i => n % i == 0) :+ n + + properDivisors(defaultInterval, latestCompactBatchId + 1).head + } else { + // default compact interval > than any divisor other than latest compact id + latestCompactBatchId + 1 + } + } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index b4f14151f1ef..eb6eed87eca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -88,9 +88,11 @@ class FileStreamSinkLog( protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion - protected override val compactInterval = sparkSession.sessionState.conf.fileSinkLogCompactInterval - require(compactInterval > 0, - s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + + protected override val defaultCompactInterval = + sparkSession.sessionState.conf.fileSinkLogCompactInterval + + require(defaultCompactInterval > 0, + s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " + "to a positive value.") override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index fe81b1560706..327b3ac26776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -38,11 +38,12 @@ class FileStreamSourceLog( import CompactibleFileStreamLog._ // Configurations about metadata compaction - protected override val compactInterval = + protected override val defaultCompactInterval: Int = sparkSession.sessionState.conf.fileSourceLogCompactInterval - require(compactInterval > 0, - s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} (was $compactInterval) to a " + - s"positive value.") + + require(defaultCompactInterval > 0, + s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} " + + s"(was $defaultCompactInterval) to a positive value.") protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSourceLogCleanupDelay diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index db7057d7da70..080729b2ca8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -70,7 +70,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * A `PathFilter` to filter only batch files */ - private val batchFilesFilter = new PathFilter { + protected val batchFilesFilter = new PathFilter { override def accept(path: Path): Boolean = isBatchFile(path) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a4e1fe679709..7469caeee3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.execution.streaming * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]]) { +case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[String] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of @@ -47,7 +47,13 @@ object OffsetSeq { * Returns a [[OffsetSeq]] with a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(offsets: Offset*): OffsetSeq = { - OffsetSeq(offsets.map(Option(_))) + def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) + + /** + * Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { + OffsetSeq(offsets.map(Option(_)), metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index d1c9d95be9fd..cc25b4474ba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -33,12 +33,13 @@ import org.apache.spark.sql.SparkSession * by a newline character. 
If a source offset is missing, then * that line will contain a string value defined in the * SERIALIZED_VOID_OFFSET variable in [[OffsetSeqLog]] companion object. - * For instance, when dealine wiht [[LongOffset]] types: - * v1 // version 1 - * {0} // LongOffset 0 - * {3} // LongOffset 3 - * - // No offset for this source i.e., an invalid JSON string - * {2} // LongOffset 2 + * For instance, when dealing with [[LongOffset]] types: + * v1 // version 1 + * metadata + * {0} // LongOffset 0 + * {3} // LongOffset 3 + * - // No offset for this source i.e., an invalid JSON string + * {2} // LongOffset 2 * ... */ class OffsetSeqLog(sparkSession: SparkSession, path: String) @@ -58,13 +59,25 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (version != OffsetSeqLog.VERSION) { throw new IllegalStateException(s"Unknown log version: ${version}") } - OffsetSeq.fill(lines.map(parseOffset).toArray: _*) + + // read metadata + val metadata = lines.next().trim match { + case "" => None + case md => Some(md) + } + OffsetSeq.fill(metadata, lines.map(parseOffset).toArray: _*) } - override protected def serialize(metadata: OffsetSeq, out: OutputStream): Unit = { + override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) - metadata.offsets.map(_.map(_.json)).foreach { offset => + + // write metadata + out.write('\n') + out.write(offsetSeq.metadata.getOrElse("").getBytes(UTF_8)) + + // write offsets, one per line + offsetSeq.offsets.map(_.map(_.json)).foreach { offset => out.write('\n') offset match { case Some(json: String) => out.write(json.getBytes(UTF_8)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala new file mode 100644 index 000000000000..2cd2157b293c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkFunSuite + +class CompactibleFileStreamLogSuite extends SparkFunSuite { + + import CompactibleFileStreamLog._ + + test("deriveCompactInterval") { + // latestCompactBatchId(4) + 1 <= default(5) + // then use latestestCompactBatchId + 1 === 5 + assert(5 === deriveCompactInterval(5, 4)) + // First divisor of 10 greater than 4 === 5 + assert(5 === deriveCompactInterval(4, 9)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index b365af76c379..a099153d2e58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming import java.io.File +import scala.collection.mutable + import org.scalatest.PrivateMethodTester import org.scalatest.time.SpanSugar._ @@ -896,32 +898,38 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("compacat metadata log") { + test("compact interval metadata log") { val _sources = PrivateMethod[Seq[Source]]('sources) val _metadataLog = PrivateMethod[FileStreamSourceLog]('metadataLog) - def verify(execution: StreamExecution) - (batchId: Long, expectedBatches: Int): Boolean = { + def verify( + execution: StreamExecution, + batchId: Long, + expectedBatches: Int, + expectedCompactInterval: Int): Boolean = { import CompactibleFileStreamLog._ val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] val metadataLog = fileSource invokePrivate _metadataLog() - if (isCompactionBatch(batchId, 2)) { + if (isCompactionBatch(batchId, expectedCompactInterval)) { val path = metadataLog.batchIdToPath(batchId) // Assert path name should be ended with compact suffix. - assert(path.getName.endsWith(COMPACT_FILE_SUFFIX)) + assert(path.getName.endsWith(COMPACT_FILE_SUFFIX), + "path does not end with compact file suffix") // Compacted batch should include all entries from start. 
val entries = metadataLog.get(batchId) - assert(entries.isDefined) - assert(entries.get.length === metadataLog.allFiles().length) - assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === entries.get.length) + assert(entries.isDefined, "Entries not defined") + assert(entries.get.length === metadataLog.allFiles().length, "clean up check") + assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === + entries.get.length, "Length check") } assert(metadataLog.allFiles().sortBy(_.batchId) === - metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId)) + metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId), + "Batch id mismatch") metadataLog.get(None, Some(batchId)).flatMap(_._2).length === expectedBatches } @@ -932,26 +940,27 @@ class FileStreamSourceSuite extends FileStreamSourceTest { ) { val fileStream = createFileStream("text", src.getCanonicalPath) val filtered = fileStream.filter($"value" contains "keep") + val updateConf = Map(SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "5") testStream(filtered)( AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), CheckAnswer("keep2", "keep3"), - AssertOnQuery(verify(_)(0L, 1)), + AssertOnQuery(verify(_, 0L, 1, 2)), AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6"), - AssertOnQuery(verify(_)(1L, 2)), + AssertOnQuery(verify(_, 1L, 2, 2)), AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9"), - AssertOnQuery(verify(_)(2L, 3)), + AssertOnQuery(verify(_, 2L, 3, 2)), StopStream, - StartStream(), - AssertOnQuery(verify(_)(2L, 3)), + StartStream(additionalConfs = updateConf), + AssertOnQuery(verify(_, 2L, 3, 2)), AddTextFileData("drop10\nkeep11", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11"), - AssertOnQuery(verify(_)(3L, 4)), + AssertOnQuery(verify(_, 3L, 4, 2)), AddTextFileData("drop12\nkeep13", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11", "keep13"), - AssertOnQuery(verify(_)(4L, 5)) + AssertOnQuery(verify(_, 4L, 5, 2)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 742833065144..a6b2d4b9ab4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -161,7 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( trigger: Trigger = ProcessingTime(0), - triggerClock: Clock = new SystemClock) + triggerClock: Clock = new SystemClock, + additionalConfs: Map[String, String] = Map.empty) extends StreamAction /** Advance the trigger clock's time manually. 
*/ @@ -240,6 +241,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = new MemorySink(stream.schema, outputMode) + val resetConfValues = mutable.Map[String, Option[String]]() @volatile var streamDeathCause: Throwable = null @@ -330,7 +332,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock) => + case StartStream(trigger, triggerClock, additionalConfs) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -338,6 +340,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + + additionalConfs.foreach(pair => { + val value = + if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None + resetConfValues(pair._1) = value + spark.conf.set(pair._1, pair._2) + }) + lastStream = currentStream currentStream = spark @@ -519,6 +529,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { currentStream.stop() } spark.streams.removeListener(statusCollector) + + // Rollback prev configuration values + resetConfValues.foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } } } From 795e9fc9213cb9941ae131aadcafddb94bde5f74 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 18 Nov 2016 11:19:49 -0800 Subject: [PATCH 251/381] [SPARK-18457][SQL] ORC and other columnar formats using HiveShim read all columns when doing a simple count ## What changes were proposed in this pull request? When reading zero columns (e.g., count(*)) from ORC or any other format that uses HiveShim, actually set the read column list to empty for Hive to use. ## How was this patch tested? Query correctness is handled by existing unit tests. I'm happy to add more if anyone can point out some case that is not covered. Reduction in data read can be verified in the UI when built with a recent version of Hadoop say: ``` build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -Phive -DskipTests clean package ``` However the default Hadoop 2.2 that is used for unit tests does not report actual bytes read and instead just full file sizes (see FileScanRDD.scala line 80). Therefore I don't think there is a good way to add a unit test for this. I tested with the following setup using above build options ``` case class OrcData(intField: Long, stringField: String) spark.range(1,1000000).map(i => OrcData(i, s"part-$i")).toDF().write.format("orc").save("orc_test") sql( s"""CREATE EXTERNAL TABLE orc_test( | intField LONG, | stringField STRING |) |STORED AS ORC |LOCATION '${System.getProperty("user.dir") + "/orc_test"}' """.stripMargin) ``` ## Results query | Spark 2.0.2 | this PR ---|---|--- `sql("select count(*) from orc_test").collect`|4.4 MB|199.4 KB `sql("select intField from orc_test").collect`|743.4 KB|743.4 KB `sql("select * from orc_test").collect`|4.4 MB|4.4 MB Author: Andrew Ray Closes #15898 from aray/sql-orc-no-col. 
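As a side note for reviewers, a hypothetical verification sketch of the Hive projection mechanics this change relies on (not part of the patch; it assumes Hive's serde2 `ColumnProjectionUtils` is on the classpath and that the projection conf keys below match the bundled Hive version):

```scala
import java.util.Collections

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils

val conf = new Configuration()
// A count(*) scan requests zero columns: the id list is empty but non-null, so with this
// change it is still forwarded to Hive instead of being short-circuited by the old
// `nonEmpty` guard in HiveShim.appendReadColumns.
ColumnProjectionUtils.appendReadColumns(conf, Collections.emptyList[Integer]())

// Inspect the projection-related properties Hive consults when reading ORC splits
// (key names are illustrative and may differ across Hive versions).
Seq("hive.io.file.readcolumn.ids", "hive.io.file.read.all.columns")
  .foreach(k => println(s"$k -> ${conf.get(k)}"))
```

With an empty projection recorded, the ORC reader can skip materializing any column for a pure row count, which is consistent with the size reduction reported in the table above.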
--- .../org/apache/spark/sql/hive/HiveShim.scala | 6 ++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 0d2a765a388a..9e9894803ce2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -69,13 +69,13 @@ private[hive] object HiveShim { } /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.nonEmpty) { + if (ids != null) { ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } - if (names != null && names.nonEmpty) { + if (names != null) { appendReadColumnNames(conf, names) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index ecb597298452..a628977af2f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets import java.sql.Timestamp +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -577,4 +579,25 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) } } + + test("Empty schema does not read data from ORC file") { + val data = Seq((1, 1), (2, 2)) + withOrcFile(data) { path => + val requestedSchema = StructType(Nil) + val conf = new Configuration() + val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get + OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) + assert(maybeOrcReader.isDefined) + val orcRecordReader = new SparkOrcNewRecordReader( + maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) + + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + try { + assert(recordsIterator.next().toString == "{null, null}") + } finally { + recordsIterator.close() + } + } + } } From 5fe0cdef1dc5163cf60d8d4bcc8bdfe8f1917cf7 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 18 Nov 2016 12:40:18 -0800 Subject: [PATCH 252/381] refine test suite of arrow vectors --- .../spark/sql/DatasetToArrowSuite.scala | 80 ++++++++++--------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala index 0e988558a2b8..4e6eb19c3b6d 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import java.io._ import java.net.{InetAddress, Socket} +import java.nio.{ByteBuffer, ByteOrder} import java.nio.channels.FileChannel import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.file.ArrowReader -import org.apache.arrow.vector.schema.ArrowRecordBatch + import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -40,28 +41,37 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { val port = ds.collectAsArrowToPython() - val clientThread: Thread = new Thread(new Runnable() { - def run() { - try { - val receiver: RecordBatchReceiver = new RecordBatchReceiver - val record: ArrowRecordBatch = receiver.read(port) - } - catch { - case e: Exception => - throw e - } - } - }) - - clientThread.start() - - try { - clientThread.join() - } catch { - case e: InterruptedException => - throw e - case _ => - } + val receiver: RecordBatchReceiver = new RecordBatchReceiver + val (buffer, numBytesRead) = receiver.connectAndRead(port) + val channel = receiver.makeFile(buffer) + val reader = new ArrowReader(channel, receiver.allocator) + + val footer = reader.readFooter() + val schema = footer.getSchema + assert(schema.getFields.size() === ds.schema.fields.length) + assert(schema.getFields.get(0).getName === ds.schema.fields(0).name) + assert(schema.getFields.get(0).isNullable === ds.schema.fields(0).nullable) + assert(schema.getFields.get(1).getName === ds.schema.fields(1).name) + assert(schema.getFields.get(1).isNullable === ds.schema.fields(1).nullable) + + val blockMetadata = footer.getRecordBatches + assert(blockMetadata.size() === 1) + + val recordBatch = reader.readRecordBatch(blockMetadata.get(0)) + val nodes = recordBatch.getNodes + assert(nodes.size() === 2) + + val firstNode = nodes.get(0) + assert(firstNode.getLength === 3) + assert(firstNode.getNullCount === 0) + + val buffers = recordBatch.getBuffers + assert(buffers.size() === 4) + + val column1 = receiver.getIntArray(buffers.get(1)) + assert(column1=== Array(1, 2, 3)) + val column2 = receiver.getIntArray(buffers.get(3)) + assert(column2 === Array(2, 3, 4)) } } @@ -69,7 +79,14 @@ class RecordBatchReceiver { val allocator = new RootAllocator(Long.MaxValue) - def array(buf: ArrowBuf): Array[Byte] = { + def getIntArray(buf: ArrowBuf): Array[Int] = { + val intBuf = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() + val intArray = Array.ofDim[Int](intBuf.remaining()) + intBuf.get(intArray) + intArray + } + + private def array(buf: ArrowBuf): Array[Byte] = { val bytes = Array.ofDim[Byte](buf.readableBytes()) buf.readBytes(bytes) bytes @@ -99,19 +116,4 @@ class RecordBatchReceiver { val arrowIns = new FileInputStream(arrowFile.getPath) arrowIns.getChannel } - - def readRecordBatch(channel: FileChannel, len: Int): ArrowRecordBatch = { - val reader = new ArrowReader(channel, allocator) - val footer = reader.readFooter() - val schema = footer.getSchema - val blocks = footer.getRecordBatches - val recordBatch = reader.readRecordBatch(blocks.get(0)) - recordBatch - } - - def read(port: Int): ArrowRecordBatch = { - val (buffer, len) = connectAndRead(port) - val fc = makeFile(buffer) - readRecordBatch(fc, len) - } } From 40d59ff5eaac6df237fe3d50186695c3806b268c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 Nov 2016 
21:45:18 +0000 Subject: [PATCH 253/381] [SPARK-18422][CORE] Fix wholeTextFiles test to pass on Windows in JavaAPISuite ## What changes were proposed in this pull request? This PR fixes the test `wholeTextFiles` in `JavaAPISuite.java`. This is failed due to the different path format on Windows. For example, the path in `container` was ``` C:\projects\spark\target\tmp\1478967560189-0/part-00000 ``` whereas `new URI(res._1()).getPath()` was as below: ``` /C:/projects/spark/target/tmp/1478967560189-0/part-00000 ``` ## How was this patch tested? Tests in `JavaAPISuite.java`. Tested via AppVeyor. **Before** Build: https://ci.appveyor.com/project/spark-test/spark/build/63-JavaAPISuite-1 Diff: https://github.com/apache/spark/compare/master...spark-test:JavaAPISuite-1 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started [error] Test org.apache.spark.JavaAPISuite.wholeTextFiles failed: java.lang.AssertionError: expected: but was:, took 0.578 sec [error] at org.apache.spark.JavaAPISuite.wholeTextFiles(JavaAPISuite.java:1089) ... ``` **After** Build started: [CORE] `org.apache.spark.JavaAPISuite` [![PR-15866](https://ci.appveyor.com/api/projects/status/github/spark-test/spark?branch=198DDA52-F201-4D2B-BE2F-244E0C1725B2&svg=true)](https://ci.appveyor.com/project/spark-test/spark/branch/198DDA52-F201-4D2B-BE2F-244E0C1725B2) Diff: https://github.com/apache/spark/compare/master...spark-test:198DDA52-F201-4D2B-BE2F-244E0C1725B2 ``` [info] Test org.apache.spark.JavaAPISuite.wholeTextFiles started ... ``` Author: hyukjinkwon Closes #15866 from HyukjinKwon/SPARK-18422. --- .../java/org/apache/spark/JavaAPISuite.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 533025ba83e7..7bebe0612f9a 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,7 +20,6 @@ import java.io.*; import java.nio.channels.FileChannel; import java.nio.ByteBuffer; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +45,7 @@ import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; @@ -1075,18 +1075,23 @@ public void wholeTextFiles() throws Exception { byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); - Files.write(content1, new File(tempDirName + "/part-00000")); - Files.write(content2, new File(tempDirName + "/part-00001")); + String path1 = new Path(tempDirName, "part-00000").toUri().getPath(); + String path2 = new Path(tempDirName, "part-00001").toUri().getPath(); + + Files.write(content1, new File(path1)); + Files.write(content2, new File(path2)); Map container = new HashMap<>(); - container.put(tempDirName+"/part-00000", new Text(content1).toString()); - container.put(tempDirName+"/part-00001", new Text(content2).toString()); + container.put(path1, new Text(content1).toString()); + container.put(path2, new Text(content2).toString()); JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + 
// Note that the paths from `wholeTextFiles` are in URI format on Windows, + // for example, file:/C:/a/b/c. + assertEquals(res._2(), container.get(new Path(res._1()).toUri().getPath())); } } From 974de14903277d96c738000e4705b4e98c7ce54d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 18 Nov 2016 14:47:04 -0800 Subject: [PATCH 254/381] update supported types --- .../scala/org/apache/spark/sql/Dataset.scala | 22 ++++++- .../spark/sql/DatasetToArrowSuite.scala | 58 +++++++++++++------ 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 86febae8aa07..cb3d11857fca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import io.netty.buffer.ArrowBuf +import org.apache.arrow.flatbuf.Precision import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.file.ArrowWriter import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} @@ -2291,6 +2292,14 @@ class Dataset[T] private[sql]( dt match { case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case StringType => + ArrowType.Utf8.INSTANCE + case DoubleType => + new ArrowType.FloatingPoint(Precision.DOUBLE) + case FloatType => + new ArrowType.FloatingPoint(Precision.SINGLE) + case BooleanType => + ArrowType.Bool.INSTANCE case _ => throw new IllegalArgumentException(s"Unsupported data type") } @@ -2341,7 +2350,18 @@ class Dataset[T] private[sql]( val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) => val validity = internalRowToValidityMap(rows, idx, field, allocator) val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) - rows.foreach { row => buf.writeInt(row.getInt(idx)) } + field.dataType match { + case IntegerType => + rows.foreach { row => buf.writeInt(row.getInt(idx)) } + case StringType => + rows.foreach { row => buf.writeByte(row.getByte(idx)) } + case DoubleType => + rows.foreach { row => buf.writeDouble(row.getDouble(idx)) } + case FloatType => + rows.foreach { row => buf.writeFloat(row.getFloat(idx)) } + case BooleanType => + rows.foreach { row => buf.writeBoolean(row.getBoolean(idx)) } + } Array(validity, buf) }.toList.asJava diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala index 4e6eb19c3b6d..c4271800f300 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala @@ -22,24 +22,39 @@ import java.net.{InetAddress, Socket} import java.nio.{ByteBuffer, ByteOrder} import java.nio.channels.FileChannel +import scala.util.Random + import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.file.ArrowReader +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils + case class ArrowIntTest(a: Int, b: Int) +case class ArrowIntDoubleTest(a: Int, b: Double) class DatasetToArrowSuite extends QueryTest with SharedSQLContext { import testImplicits._ - test("Collect as arrow to python") { + final val numElements = 4 + @transient var dataset: Dataset[_] = _ + @transient var column1: Seq[Int] = _ + @transient var 
column2: Seq[Double] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + column1 = Seq.fill(numElements)(Random.nextInt) + column2 = Seq.fill(numElements)(Random.nextDouble) + dataset = column1.zip(column2).map{ case (c1, c2) => ArrowIntDoubleTest(c1, c2) }.toDS() + } - val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS() + test("Collect as arrow to python") { - val port = ds.collectAsArrowToPython() + val port = dataset.collectAsArrowToPython() val receiver: RecordBatchReceiver = new RecordBatchReceiver val (buffer, numBytesRead) = receiver.connectAndRead(port) @@ -48,11 +63,13 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { val footer = reader.readFooter() val schema = footer.getSchema - assert(schema.getFields.size() === ds.schema.fields.length) - assert(schema.getFields.get(0).getName === ds.schema.fields(0).name) - assert(schema.getFields.get(0).isNullable === ds.schema.fields(0).nullable) - assert(schema.getFields.get(1).getName === ds.schema.fields(1).name) - assert(schema.getFields.get(1).isNullable === ds.schema.fields(1).nullable) + assert(schema.getFields.size() === dataset.schema.fields.length) + assert(schema.getFields.get(0).getName === dataset.schema.fields(0).name) + assert(schema.getFields.get(0).isNullable === dataset.schema.fields(0).nullable) + assert(schema.getFields.get(0).getType.isInstanceOf[ArrowType.Int]) + assert(schema.getFields.get(1).getName === dataset.schema.fields(1).name) + assert(schema.getFields.get(1).isNullable === dataset.schema.fields(1).nullable) + assert(schema.getFields.get(1).getType.isInstanceOf[ArrowType.FloatingPoint]) val blockMetadata = footer.getRecordBatches assert(blockMetadata.size() === 1) @@ -62,16 +79,16 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { assert(nodes.size() === 2) val firstNode = nodes.get(0) - assert(firstNode.getLength === 3) + assert(firstNode.getLength === numElements) assert(firstNode.getNullCount === 0) val buffers = recordBatch.getBuffers assert(buffers.size() === 4) - val column1 = receiver.getIntArray(buffers.get(1)) - assert(column1=== Array(1, 2, 3)) - val column2 = receiver.getIntArray(buffers.get(3)) - assert(column2 === Array(2, 3, 4)) + val column1Read = receiver.getIntArray(buffers.get(1)) + assert(column1Read === column1) + val column2Read = receiver.getDoubleArray(buffers.get(3)) + assert(column2Read === column2) } } @@ -80,10 +97,17 @@ class RecordBatchReceiver { val allocator = new RootAllocator(Long.MaxValue) def getIntArray(buf: ArrowBuf): Array[Int] = { - val intBuf = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() - val intArray = Array.ofDim[Int](intBuf.remaining()) - intBuf.get(intArray) - intArray + val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer() + val resultArray = Array.ofDim[Int](buffer.remaining()) + buffer.get(resultArray) + resultArray + } + + def getDoubleArray(buf: ArrowBuf): Array[Double] = { + val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer() + val resultArray = Array.ofDim[Double](buffer.remaining()) + buffer.get(resultArray) + resultArray } private def array(buf: ArrowBuf): Array[Byte] = { From e5f5c29e021d504284fe5ad1a77dcd5a992ac10a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:13:02 -0800 Subject: [PATCH 255/381] [SPARK-18477][SS] Enable interrupts for HDFS in HDFSMetadataLog ## What changes were proposed in this pull request? 
HDFS `write` may just hang until timeout if some network error happens. It's better to enable interrupts to allow stopping the query fast on HDFS. This PR just changes the logic to only disable interrupts for local file system, as HADOOP-10622 only happens for local file system. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15911 from zsxwing/interrupt-on-dfs. --- .../execution/streaming/HDFSMetadataLog.scala | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 080729b2ca8d..d95ec7f67feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -105,25 +105,34 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** * Store the metadata for the specified batchId and return `true` if successful. If the batchId's * metadata has already been stored, this method will return `false`. - * - * Note that this method must be called on a [[org.apache.spark.util.UninterruptibleThread]] - * so that interrupts can be disabled while writing the batch file. This is because there is a - * potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread - * running "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our - * case, `writeBatch` creates a file using HDFS API and calls "Shell.runCommand" to set the - * file permissions, and can get deadlocked if the stream execution thread is stopped by - * interrupt. Hence, we make sure that this method is called on [[UninterruptibleThread]] which - * allows us to disable interrupts here. Also see SPARK-14131. */ override def add(batchId: Long, metadata: T): Boolean = { get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - Thread.currentThread match { - case ut: UninterruptibleThread => - ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } - case _ => - throw new IllegalStateException( - "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") + if (fileManager.isLocalFileSystem) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // When using a local file system, "writeBatch" must be called on a + // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled + // while writing the batch file. This is because there is a potential dead-lock in + // Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread running + // "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our case, + // `writeBatch` creates a file using HDFS API and will call "Shell.runCommand" to set + // the file permission if using the local file system, and can get deadlocked if the + // stream execution thread is stopped by interrupt. Hence, we make sure that + // "writeBatch" is called on [[UninterruptibleThread]] which allows us to disable + // interrupts here. Also see SPARK-14131. 
+ ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } + case _ => + throw new IllegalStateException( + "HDFSMetadataLog.add() on a local file system must be executed on " + + "a o.a.spark.util.UninterruptibleThread") + } + } else { + // For a distributed file system, such as HDFS or S3, if the network is broken, write + // operations may just hang until timeout. We should enable interrupts to allow stopping + // the query fast. + writeBatch(batchId, metadata, serialize) } true } @@ -298,6 +307,9 @@ object HDFSMetadataLog { /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ def delete(path: Path): Unit + + /** Whether the file systme is a local FS. */ + def isLocalFileSystem: Boolean } /** @@ -342,6 +354,13 @@ object HDFSMetadataLog { // ignore if file has already been deleted } } + + override def isLocalFileSystem: Boolean = fc.getDefaultFileSystem match { + case _: local.LocalFs | _: local.RawLocalFs => + // LocalFs = RawLocalFs + ChecksumFs + true + case _ => false + } } /** @@ -398,5 +417,12 @@ object HDFSMetadataLog { // ignore if file has already been deleted } } + + override def isLocalFileSystem: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => + // LocalFileSystem = RawLocalFileSystem + ChecksumFileSystem + true + case _ => false + } } } From 6f7ff75091154fed7649ea6d79e887aad9fbde6a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Nov 2016 16:34:11 -0800 Subject: [PATCH 256/381] [SPARK-18505][SQL] Simplify AnalyzeColumnCommand ## What changes were proposed in this pull request? I'm spending more time at the design & code level for cost-based optimizer now, and have found a number of issues related to maintainability and compatibility that I will like to address. This is a small pull request to clean up AnalyzeColumnCommand: 1. Removed warning on duplicated columns. Warnings in log messages are useless since most users that run SQL don't see them. 2. Removed the nested updateStats function, by just inlining the function. 3. Renamed a few functions to better reflect what they do. 4. Removed the factory apply method for ColumnStatStruct. It is a bad pattern to use a apply method that returns an instantiation of a class that is not of the same type (ColumnStatStruct.apply used to return CreateNamedStruct). 5. Renamed ColumnStatStruct to just AnalyzeColumnCommand. 6. Added more documentation explaining some of the non-obvious return types and code blocks. In follow-up pull requests, I'd like to address the following: 1. Get rid of the Map[String, ColumnStat] map, since internally we should be using Attribute to reference columns, rather than strings. 2. Decouple the fields exposed by ColumnStat and internals of Spark SQL's execution path. Currently the two are coupled because ColumnStat takes in an InternalRow. 3. Correctness: Remove code path that stores statistics in the catalog using the base64 encoding of the UnsafeRow format, which is not stable across Spark versions. 4. Clearly document the data representation stored in the catalog for statistics. ## How was this patch tested? Affected test cases have been updated. Author: Reynold Xin Closes #15933 from rxin/SPARK-18505. 
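To keep the behavior concrete, a minimal usage sketch of the command this cleanup touches (hypothetical table and column names, assuming an active `SparkSession` named `spark`):

```scala
// Create a small table to analyze; the names here are placeholders.
spark.range(1000).selectExpr("id AS key", "id % 7 AS bucket").write.saveAsTable("events")

// Runs AnalyzeColumnCommand: computes the row count plus per-column statistics for the
// listed columns and stores them, together with sizeInBytes, in the catalog.
spark.sql("ANALYZE TABLE events COMPUTE STATISTICS FOR COLUMNS key, bucket")

// Table-level statistics show up in the extended description; the per-column stats are
// kept as catalog table properties and feed the cost-based optimizer work mentioned above.
spark.sql("DESCRIBE EXTENDED events").show(truncate = false)
```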
--- .../command/AnalyzeColumnCommand.scala | 115 ++++++++++-------- .../spark/sql/StatisticsColumnSuite.scala | 2 +- .../org/apache/spark/sql/StatisticsTest.scala | 7 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- .../sql/hive/client/HiveClientImpl.scala | 2 +- 5 files changed, 74 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4aff0..7fc57d09e924 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.command -import scala.collection.mutable - +import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases @@ -44,13 +43,16 @@ case class AnalyzeColumnCommand( val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) - relation match { + // Compute total size + val (catalogTable: CatalogTable, sizeInBytes: Long) = relation match { case catalogRel: CatalogRelation => - updateStats(catalogRel.catalogTable, + // This is a Hive serde format table + (catalogRel.catalogTable, AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => - updateStats(logicalRel.catalogTable.get, + // This is a data source format table + (logicalRel.catalogTable.get, AnalyzeTableCommand.calculateTotalSize(sessionState, logicalRel.catalogTable.get)) case otherRelation => @@ -58,45 +60,45 @@ case class AnalyzeColumnCommand( s"${otherRelation.nodeName}.") } - def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val (rowCount, columnStats) = computeColStats(sparkSession, relation) - // We also update table-level stats in order to keep them consistent with column-level stats. - val statistics = Statistics( - sizeInBytes = newTotalSize, - rowCount = Some(rowCount), - // Newly computed column stats should override the existing ones. - colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - } + // Compute stats for each column + val (rowCount, newColStats) = + AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + + // We also update table-level stats in order to keep them consistent with column-level stats. + val statistics = Statistics( + sizeInBytes = sizeInBytes, + rowCount = Some(rowCount), + // Newly computed column stats should override the existing ones. + colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) + + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } +} +object AnalyzeColumnCommand extends Logging { + + /** + * Compute stats for the given columns. + * @return (row count, map from column name to ColumnStats) + * + * This is visible for testing. 
+ */ def computeColStats( sparkSession: SparkSession, - relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { + relation: LogicalPlan, + columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - // check correctness of column names - val attributesToAnalyze = mutable.MutableList[Attribute]() - val duplicatedColumns = mutable.MutableList[String]() + // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - columnNames.foreach { col => + val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) - // do deduplication - if (!attributesToAnalyze.contains(expr)) { - attributesToAnalyze += expr - } else { - duplicatedColumns += col - } - } - if (duplicatedColumns.nonEmpty) { - logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " + - s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " + - s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.") - } + exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + }).toSeq // Collect statistics per column. // The first element in the result will be the overall row count, the following elements @@ -104,22 +106,21 @@ case class AnalyzeColumnCommand( // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head // unwrap the result + // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = ColumnStatStruct.numStatFields(expr.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) }.toMap (rowCount, columnStats) } -} -object ColumnStatStruct { private val zero = Literal(0, LongType) private val one = Literal(1, LongType) @@ -137,7 +138,11 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + /** + * Creates a struct that groups the sequence of expressions together. This is used to create + * one top level struct per column. + */ + private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -161,6 +166,7 @@ object ColumnStatStruct { Seq(numNulls(e), numTrues(e), numFalses(e)) } + // TODO(rxin): Get rid of this function. def numStatFields(dataType: DataType): Int = { dataType match { case BinaryType | BooleanType => 3 @@ -168,14 +174,25 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { - // Use aggregate functions to compute statistics we need. 
- case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) - case StringType => getStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => getStruct(binaryColumnStat(attr)) - case BooleanType => getStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") + /** + * Creates a struct expression that contains the statistics to collect for a column. + * + * @param attr column to collect statistics + * @param relativeSD relative error for approximate number of distinct values. + */ + def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { + attr.dataType match { + case _: NumericType | TimestampType | DateType => + createStruct(numericColumnStat(attr, relativeSD)) + case StringType => + createStruct(stringColumnStat(attr, relativeSD)) + case BinaryType => + createStruct(binaryColumnStat(attr)) + case BooleanType => + createStruct(booleanColumnStat(attr)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${attr.name} of data type: ${attr.dataType}.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index f1a201abd8da..e866ac2cb3b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -79,7 +79,7 @@ class StatisticsColumnSuite extends StatisticsTest { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 5134ac0e7e5b..915ee0d31bca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ + trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( @@ -36,7 +37,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val (_, columnStats) = - AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) expectedColStatsSeq.foreach { case (field, expectedColStat) => assert(columnStats.contains(field.name)) val colStat = columnStats(field.name) @@ -48,7 +49,7 @@ trait StatisticsTest extends 
QueryTest with SharedSQLContext { // check if we get the same colStat after encoding and decoding val encodedCS = colStat.toString - val numFields = ColumnStatStruct.numStatFields(field.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( dataType = field.dataType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index cacffcf33c26..5dbb4024bbee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -634,7 +634,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { case f if colStatsProps.contains(f.name) => - val numFields = ColumnStatStruct.numStatFields(f.dataType) + val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap tableWithSchema.copy( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2bf9a26b0b7f..daae8523c636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,7 +97,7 @@ private[hive] class HiveClientImpl( } // Create an internal session state for this HiveClientImpl. - val state = { + val state: SessionState = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) From 2a40de408b5eb47edba92f9fe92a42ed1e78bf98 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Nov 2016 16:34:38 -0800 Subject: [PATCH 257/381] [SPARK-18497][SS] Make ForeachSink support watermark ## What changes were proposed in this pull request? The issue in ForeachSink is the new created DataSet still uses the old QueryExecution. When `foreachPartition` is called, `QueryExecution.toString` will be called and then fail because it doesn't know how to plan EventTimeWatermark. This PR just replaces the QueryExecution with IncrementalExecution to fix the issue. ## How was this patch tested? `test("foreach with watermark")`. Author: Shixiong Zhu Closes #15934 from zsxwing/SPARK-18497. 
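For reference, the user-facing pattern that used to fail, and that the new test below exercises, looks roughly like this (a sketch only: `source` is a placeholder streaming DataFrame with an `eventTime` timestamp column):

```scala
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.functions.{col, count, window}

// Watermarked windowed aggregation written through the foreach sink. Before this fix,
// planning the foreach side failed because the reused QueryExecution could not plan
// the EventTimeWatermark operator.
val query = source
  .withWatermark("eventTime", "10 seconds")
  .groupBy(window(col("eventTime"), "5 seconds"))
  .agg(count("*"))
  .writeStream
  .outputMode("complete")
  .foreach(new ForeachWriter[Row] {
    override def open(partitionId: Long, version: Long): Boolean = true
    override def process(value: Row): Unit = println(value)
    override def close(errorOrNull: Throwable): Unit = ()
  })
  .start()
```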
--- .../sql/execution/streaming/ForeachSink.scala | 16 ++++----- .../streaming/ForeachSinkSuite.scala | 35 +++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index f5c550dd6ac3..c93fcfb77cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -47,22 +47,22 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria // method supporting incremental planning. But in the long run, we should generally make newly // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to // resolve). - + val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution] val datasetWithIncrementalExecution = - new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) { + new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) { override lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) // was originally: sparkSession.sessionState.executePlan(deserialized) ... - val incrementalExecution = new IncrementalExecution( + val newIncrementalExecution = new IncrementalExecution( this.sparkSession, deserialized, - data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, - data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, - data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId, - data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark) - incrementalExecution.toRdd.mapPartitions { rows => + incrementalExecution.outputMode, + incrementalExecution.checkpointLocation, + incrementalExecution.currentBatchId, + incrementalExecution.currentEventTimeWatermark) + newIncrementalExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType)) }.asInstanceOf[RDD[T]] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9e059216110f..ee6261036fdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -169,6 +170,40 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf assert(errorEvent.error.get.getMessage === "error") } } + + test("foreach with watermark") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = windowedAggregation + .writeStream + .outputMode(OutputMode.Complete) + .foreach(new TestForeachWriter()) 
+ .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + + val allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 1) + val expectedEvents = Seq( + ForeachSinkSuite.Open(partition = 0, version = 0), + ForeachSinkSuite.Process(value = 3), + ForeachSinkSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + } finally { + query.stop() + } + } } /** A global object to collect events in the executor */ From db9fb9baacbf8640dd37a507b7450db727c7e6ea Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 09:00:11 +0000 Subject: [PATCH 258/381] [SPARK-18448][CORE] SparkSession should implement java.lang.AutoCloseable like JavaSparkContext ## What changes were proposed in this pull request? Just adds `close()` + `Closeable` as a synonym for `stop()`. This makes it usable in Java in try-with-resources, as suggested by ash211 (`Closeable` extends `AutoCloseable` BTW) ## How was this patch tested? Existing tests Author: Sean Owen Closes #15932 from srowen/SPARK-18448. --- .../main/scala/org/apache/spark/sql/SparkSession.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3045eb69f427..58b2ab395717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.beans.Introspector +import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ @@ -72,7 +73,7 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState]) - extends Serializable with Logging { self => + extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { this(sc, None) @@ -647,6 +648,13 @@ class SparkSession private( sparkContext.stop() } + /** + * Synonym for `stop()`. + * + * @since 2.2.0 + */ + override def close(): Unit = stop() + /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. From d5b1d5fc80153571c308130833d0c0774de62c92 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 19 Nov 2016 11:24:15 +0000 Subject: [PATCH 259/381] [SPARK-18445][BUILD][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that`/`'''Note:'''` across Scala/Java API documentation ## What changes were proposed in this pull request? It seems in Scala/Java, - `Note:` - `NOTE:` - `Note that` - `'''Note:'''` - `note` This PR proposes to fix those to `note` to be consistent. **Before** - Scala ![2016-11-17 6 16 39](https://cloud.githubusercontent.com/assets/6477701/20383180/1a7aed8c-acf2-11e6-9611-5eaf6d52c2e0.png) - Java ![2016-11-17 6 14 41](https://cloud.githubusercontent.com/assets/6477701/20383096/c8ffc680-acf1-11e6-914a-33460bf1401d.png) **After** - Scala ![2016-11-17 6 16 44](https://cloud.githubusercontent.com/assets/6477701/20383167/09940490-acf2-11e6-937a-0d5e1dc2cadf.png) - Java ![2016-11-17 6 13 39](https://cloud.githubusercontent.com/assets/6477701/20383132/e7c2a57e-acf1-11e6-9c47-b849674d4d88.png) ## How was this patch tested? The notes were found via ```bash grep -r "NOTE: " . 
| \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// NOTE: " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ # note that this is a regular expression. So actual matches were mostly `org/apache/spark/api/java/functions ...` -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "Note that " . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// Note that " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "Note: " . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// Note: " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` ```bash grep -r "'''Note:'''" . | \ # Note:|NOTE:|Note that|'''Note:''' grep -v "// '''Note:''' " | \ # starting with // does not appear in API documentation. grep -E '.scala|.java' | \ # java/scala files grep -v Suite | \ # exclude tests grep -v Test | \ # exclude tests grep -e 'org.apache.spark.api.java' \ # packages appear in API documenation -e 'org.apache.spark.api.java.function' \ -e 'org.apache.spark.api.r' \ ... ``` And then fixed one by one comparing with API documentation/access modifiers. After that, manually tested via `jekyll build`. Author: hyukjinkwon Closes #15889 from HyukjinKwon/SPARK-18437. 
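For illustration, a hypothetical before/after of the scaladoc convention this patch applies (the classes below are invented, not taken from the diff):

```scala
/**
 * Partitions records into roughly equal ranges.
 *
 * Note: the actual number of partitions may be smaller than requested.
 */
class RangeLikePartitionerBefore // the caveat renders as ordinary prose

/**
 * Partitions records into roughly equal ranges.
 *
 * @note The actual number of partitions may be smaller than requested.
 */
class RangeLikePartitionerAfter // `@note` renders as a highlighted note block in the docs
```

The same rewrite is what the `Partitioner` hunk below performs on the real API documentation.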
--- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../scala/org/apache/spark/SparkContext.scala | 47 ++++++++------- .../apache/spark/api/java/JavaDoubleRDD.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 26 ++++---- .../org/apache/spark/api/java/JavaRDD.scala | 12 ++-- .../apache/spark/api/java/JavaRDDLike.scala | 3 +- .../spark/api/java/JavaSparkContext.scala | 21 +++---- .../api/java/JavaSparkStatusTracker.scala | 2 +- .../io/SparkHadoopMapReduceWriter.scala | 2 +- .../apache/spark/io/CompressionCodec.scala | 23 ++++--- .../apache/spark/partial/BoundedDouble.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 8 +-- .../apache/spark/rdd/DoubleRDDFunctions.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 6 +- .../apache/spark/rdd/PairRDDFunctions.scala | 23 +++---- .../spark/rdd/PartitionPruningRDD.scala | 2 +- .../spark/rdd/PartitionwiseSampledRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 46 +++++++------- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../spark/rdd/SequenceFileRDDFunctions.scala | 5 +- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 2 +- .../spark/scheduler/AccumulableInfo.scala | 10 ++-- .../spark/serializer/JavaSerializer.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 2 +- .../apache/spark/serializer/Serializer.scala | 2 +- .../apache/spark/storage/StorageUtils.scala | 19 +++--- .../org/apache/spark/util/AccumulatorV2.scala | 5 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- docs/mllib-isotonic-regression.md | 2 +- docs/streaming-programming-guide.md | 2 +- .../spark/sql/kafka010/KafkaSource.scala | 2 +- .../spark/streaming/kafka/KafkaUtils.scala | 8 +-- .../streaming/kinesis/KinesisUtils.scala | 60 +++++++++---------- .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../apache/spark/graphx/impl/GraphImpl.scala | 2 +- .../apache/spark/graphx/lib/PageRank.scala | 2 +- .../org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../scala/org/apache/spark/ml/Model.scala | 2 +- .../DecisionTreeClassifier.scala | 6 +- .../ml/classification/GBTClassifier.scala | 6 +- .../classification/LogisticRegression.scala | 36 +++++------ .../spark/ml/clustering/GaussianMixture.scala | 6 +- .../spark/ml/feature/MinMaxScaler.scala | 3 +- .../spark/ml/feature/OneHotEncoder.scala | 3 +- .../org/apache/spark/ml/feature/PCA.scala | 5 +- .../spark/ml/feature/StopWordsRemover.scala | 5 +- .../spark/ml/feature/StringIndexer.scala | 6 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 6 +- .../GeneralizedLinearRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 28 +++++---- .../ml/source/libsvm/LibSVMDataSource.scala | 2 +- .../ml/tree/impl/GradientBoostedTrees.scala | 4 +- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../classification/LogisticRegression.scala | 28 +++++---- .../spark/mllib/classification/SVM.scala | 20 ++++--- .../mllib/clustering/GaussianMixture.scala | 8 +-- .../spark/mllib/clustering/KMeans.scala | 8 ++- .../apache/spark/mllib/clustering/LDA.scala | 4 +- .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +- .../mllib/evaluation/AreaUnderCurve.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 6 +- .../linalg/distributed/BlockMatrix.scala | 2 +- 
.../linalg/distributed/IndexedRowMatrix.scala | 5 +- .../mllib/linalg/distributed/RowMatrix.scala | 21 ++++--- .../spark/mllib/optimization/Gradient.scala | 3 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 2 +- .../MatrixFactorizationModel.scala | 6 +- .../apache/spark/mllib/stat/Statistics.scala | 34 +++++------ .../spark/mllib/tree/DecisionTree.scala | 32 +++++----- .../apache/spark/mllib/tree/loss/Loss.scala | 12 ++-- .../mllib/tree/model/treeEnsembleModels.scala | 4 +- pom.xml | 7 +++ project/SparkBuild.scala | 3 +- python/pyspark/mllib/stat/KernelDensity.py | 2 +- python/pyspark/mllib/util.py | 2 +- python/pyspark/rdd.py | 4 +- python/pyspark/streaming/kafka.py | 4 +- .../scala/org/apache/spark/sql/Encoders.scala | 8 +-- .../sql/types/CalendarIntervalType.scala | 4 +- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 56 ++++++++--------- .../org/apache/spark/sql/SQLContext.scala | 7 ++- .../org/apache/spark/sql/SparkSession.scala | 9 +-- .../apache/spark/sql/UDFRegistration.scala | 3 +- .../execution/streaming/state/package.scala | 4 +- .../sql/expressions/UserDefinedFunction.scala | 8 ++- .../org/apache/spark/sql/functions.scala | 22 +++---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 10 ++-- .../sql/util/QueryExecutionListener.scala | 8 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../spark/streaming/StreamingContext.scala | 18 +++--- .../streaming/api/java/JavaPairDStream.scala | 2 +- .../api/java/JavaStreamingContext.scala | 40 +++++++------ .../spark/streaming/dstream/DStream.scala | 4 +- .../dstream/MapWithStateDStream.scala | 2 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 2 +- 105 files changed, 517 insertions(+), 436 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e7..af913454fce6 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -139,7 +139,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed6..f83f5278e8b8 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. 
*/ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c9..04d657c09afd 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 25a3d609a6b0..1261e3e73576 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -281,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -700,7 +700,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -927,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -970,7 +970,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -995,7 +995,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1034,7 +1034,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1058,7 +1058,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1084,7 +1084,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1124,7 +1124,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1150,7 +1150,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1169,7 +1169,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1199,7 +1199,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1330,16 +1330,18 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) @@ -1550,7 +1552,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1572,7 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. 
If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1590,7 +1592,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1639,7 +1641,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -2298,7 +2301,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -2323,7 +2326,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(): SparkContext = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad51..a32a4b28c173 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -153,7 +153,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +256,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. 
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaa..bff5a29bb60f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -206,7 +206,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +223,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +234,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +258,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +268,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,7 +404,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. 
* - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -409,7 +415,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -539,7 +545,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index d67cff64e6e4..ccd94f876e0b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -99,27 +99,29 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD with a random seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** * Return a sampled subset of this RDD, with a user-supplied seed. - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -157,7 +159,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. 
The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf21..eda16d957cc5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd5..38d347aeab8c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -298,7 +298,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +316,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +366,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +396,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. 
@@ -416,7 +416,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +437,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +458,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +487,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +694,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -811,7 +811,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. 
*/ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced..6aa290ecd7bb 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 796439276a22..aaeb3d003829 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -119,7 +119,7 @@ object SparkHadoopMapReduceWriter extends Logging { } } - /** Write a RDD partition out in a single Spark task. */ + /** Write an RDD partition out in a single Spark task. */ private def executeTask[K, V: ClassTag]( context: TaskContext, jobTrackerId: String, diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef75..6ba79e506a64 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. 
This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6..8f579c5a3033 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f0..a091f06b4ed7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c5..f3ab324d5911 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -158,7 +158,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. 
[0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 36a2f5c87e37..86351b8c575e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.hadoopRDD()]] */ @DeveloperApi class HadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 488e777fea37..a5965f597038 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] */ @DeveloperApi class NewHadoopRDD[K, V]( diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index f9b9631d9e7c..33e695ec5322 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -57,8 +57,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -66,6 +66,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -361,7 +364,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -488,11 +491,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -512,11 +515,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -633,7 +636,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. 
If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -1014,7 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1068,7 +1071,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee..ce75a16031a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b..6a89ea878646 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cded899db1f5..bff2b8f1d06c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -428,7 +428,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -466,14 +466,14 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. - * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -537,13 +537,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -618,7 +618,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -630,7 +630,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -646,7 +646,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -674,7 +674,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. 
The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -687,7 +687,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -702,7 +702,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -921,7 +921,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -934,7 +934,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1182,7 +1182,7 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -1272,7 +1272,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. 
If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1286,7 +1286,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1305,10 +1305,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1370,7 +1370,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1393,7 +1393,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1438,7 +1438,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6be..1070bb96b252 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -32,7 +32,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. 
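The `zipWithIndex`/`zipWithUniqueId` notes above warn that index assignment follows partition layout, which is unstable for unordered RDDs. A small sketch of the sort-first pattern the note recommends (function name is illustrative):

```scala
import org.apache.spark.rdd.RDD

// zipWithIndex assigns indices by partition and by position within a partition,
// so an unordered RDD (e.g. the output of groupBy) gets non-deterministic indices.
// Sorting first makes the assignment reproducible across evaluations.
def stableIndices(words: RDD[String]): RDD[(String, Long)] =
  words.sortBy(identity).zipWithIndex()
```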
It manages process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 9f800e3a0953..e0a29b48314f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c7..86a332790fb0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c6..8425b211d6ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afe..0a5fe5a1d3ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. 
- * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b..f60dcfddfdc2 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe107..19e020c968a9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0c..afe6cd86059f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. 
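The serializer notes converted above all stress that serialization is a per-application concern with no cross-version wire guarantee. A brief configuration sketch (class and app names are illustrative) of switching a single application to Kryo:

```scala
import org.apache.spark.SparkConf

// Example payload type for this application only.
case class ClickEvent(userId: Long, url: String)

val conf = new SparkConf()
  .setAppName("kryo-sketch")
  // Replace the default JavaSerializer with Kryo for this application.
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  // Registering classes avoids writing full class names into the serialized stream.
  .registerKryoClasses(Array(classOf[ClickEvent]))
```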
*/ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0..e12f2e6095d5 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. 
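The `StorageStatus` notes above contrast the O(1)/O(RDDs) accessors with scanning `blocks` directly. A sketch of using them, assuming a live `SparkContext` and the Spark 2.x developer API `SparkContext.getExecutorStorageStatus` (the RDD id and helper name are illustrative):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.storage.RDDBlockId

// Reports cached-block information per executor using the fast accessors
// recommended by the notes, instead of cloning the full block map.
def describeCachedRdd(sc: SparkContext, rddId: Int): Unit = {
  for (status <- sc.getExecutorStorageStatus) {
    val first = RDDBlockId(rddId, splitIndex = 0)
    println(s"${status.blockManagerId.host}: " +
      s"blocks=${status.numBlocks}, rddBlocks=${status.numRddBlocksById(rddId)}, " +
      s"hasFirstPartition=${status.containsBlock(first)}")
  }
}
```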
*/ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd3913132..1326f0977c24 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -59,8 +59,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } /** - * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before use, or it will throw exception. + * Returns true if this accumulator has been registered. + * + * @note All accumulators must be registered before use, or it will throw exception. */ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bec95d13d193..5e8a854e46a0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2076,7 +2076,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } /** - * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s * denotes a shuffle dependency): diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index d90905a86ade..ca84551506b2 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -27,7 +27,7 @@ best fitting the original data points. [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). -The training input is a RDD of tuples of three double values that represent +The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 0b0315b36650..18fc1cd93482 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2191,7 +2191,7 @@ consistent batch processing times. Make sure you set the CMS GC on both the driv - When data is received from a stream source, receiver creates blocks of data. A new block of data is generated every blockInterval milliseconds. N blocks of data are created during the batchInterval where N = batchInterval/blockInterval. These blocks are distributed by the BlockManager of the current executor to the block managers of other executors. After that, the Network Input Tracker running on the driver is informed about the block locations for further processing. -- A RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. 
blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. +- An RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. - The map tasks on the blocks are processed in the executors (one that received the block, and another where the block was replicated) that has the blocks irrespective of block interval, unless non-local scheduling kicks in. Having bigger blockinterval means bigger blocks. A high value of `spark.locality.wait` increases the chance of processing a block on the local node. A balance needs to be found out between these two parameters to ensure that the bigger blocks are processed locally. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 5bcc5124b091..341081a338c0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -279,7 +279,7 @@ private[kafka010] case class KafkaSource( } }.toArray - // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index b17e19807794..56f0cb0b166a 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -223,7 +223,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param sc SparkContext object * @param kafkaParams Kafka @@ -255,7 +255,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. * @@ -303,7 +303,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param jsc JavaSparkContext object * @param kafkaParams Kafka @@ -340,7 +340,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. 
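The `KafkaUtils.createRDD` docs touched above describe reading a bounded slice of a topic as an RDD. A sketch against the 0.8 connector, with broker address, topic and offsets purely illustrative:

```scala
import kafka.serializer.StringDecoder
import org.apache.spark.SparkContext
import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

// Reads a fixed, explicitly bounded range of each partition as a batch RDD.
def readSlice(sc: SparkContext): Unit = {
  val kafkaParams = Map("metadata.broker.list" -> "broker1:9092")
  val offsetRanges = Array(
    // topic, partition, fromOffset (inclusive), untilOffset (exclusive)
    OffsetRange("events", 0, 0L, 1000L),
    OffsetRange("events", 1, 0L, 1000L))
  val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
    sc, kafkaParams, offsetRanges)
  println(rdd.count())
}
```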
* diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index a0007d33d625..b2daffa34ccb 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -33,10 +33,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -57,6 +53,10 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T: ClassTag]( ssc: StreamingContext, @@ -81,10 +81,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -107,6 +103,9 @@ object KinesisUtils { * Kinesis `Record`, which contains both message data, and metadata. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T: ClassTag]( @@ -134,10 +133,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -156,6 +151,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
*/ def createStream( ssc: StreamingContext, @@ -178,10 +177,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -202,6 +197,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( ssc: StreamingContext, @@ -225,10 +223,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -250,6 +244,10 @@ object KinesisUtils { * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. * @param recordClass Class of the records in DStream + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T]( jssc: JavaStreamingContext, @@ -272,10 +270,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -299,6 +293,9 @@ object KinesisUtils { * @param recordClass Class of the records in DStream * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T]( @@ -326,10 +323,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. 
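The relocated Kinesis notes distinguish the overloads that take explicit AWS keys from those that fall back to `DefaultAWSCredentialsProviderChain` on the workers. A sketch of the latter, with application, stream, endpoint and region names illustrative:

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

// Omitting awsAccessKeyId/awsSecretKey selects the overload that resolves
// credentials on the workers via DefaultAWSCredentialsProviderChain.
def kinesisStream(ssc: StreamingContext) =
  KinesisUtils.createStream(
    ssc,
    "metrics-app",                             // Kinesis application name (KCL/DynamoDB)
    "metrics-stream",                          // Kinesis stream name
    "https://kinesis.us-east-1.amazonaws.com", // endpoint URL
    "us-east-1",                               // region name
    InitialPositionInStream.LATEST,            // where to start if no checkpoint exists
    Seconds(10),                               // KCL checkpoint interval
    StorageLevel.MEMORY_AND_DISK_2)            // recommended storage level
```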
- * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -348,6 +341,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream( jssc: JavaStreamingContext, @@ -367,10 +364,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -391,6 +384,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 905c33834df1..a4d81a680979 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -221,7 +221,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) assert(collectedData.toSet === testData.toSet) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index e18831382d4d..381011009999 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -42,7 +42,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges - /** Return a RDD that brings edges together with their source and destination vertices. */ + /** Return an RDD that brings edges together with their source and destination vertices. 
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { replicatedVertexView.upgrade(vertices, true, true) replicatedVertexView.edges.partitionsRDD.mapPartitions(_.flatMap { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index c0c3c73463aa..f926984aa633 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -58,7 +58,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no + * @note This is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. */ object PageRank extends Logging { diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 2e4a58dc6291..22e4ec693b1f 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.Since /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @Since("2.0.0") sealed trait Vector extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 252acc156583..c581fed17727 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. - * Note: For ensembles' component Models, this value can be null. + * @note For ensembles' component Models, this value can be null. */ @transient var parent: Estimator[M] = _ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bb192ab5f25a..7424031ed460 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -207,9 +207,9 @@ class DecisionTreeClassificationModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestClassifier]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. 
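The `featureImportances` note above recommends an ensemble when single-tree importances are too noisy. A sketch (assuming the default "label"/"features" columns; the tree count is illustrative):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.sql.DataFrame

// Averaging importances over many trees reduces the variance seen with a
// single decision tree, as the note suggests.
def importances(train: DataFrame) = {
  val rf = new RandomForestClassifier().setNumTrees(50)
  val model = rf.fit(train)
  model.featureImportances // Vector whose entries sum to 1
}
```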
*/ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8f164e8c14b..52f93f5a6b34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.types.DoubleType * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. * @@ -54,6 +53,8 @@ import org.apache.spark.sql.types.DoubleType * based on the loss function, whereas the original gradient boosting method does not. * - We expect to implement TreeBoost in the future: * [https://issues.apache.org/jira/browse/SPARK-4240] + * + * @note Multiclass labels are not currently supported. */ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @@ -169,10 +170,11 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * model for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. + * + * @note Multiclass labels are not currently supported. */ @Since("1.6.0") class GBTClassificationModel private[ml]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18b9b3043db8..71a7fe53c15f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1191,8 +1191,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") @@ -1200,8 +1200,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Computes the area under the receiver operating characteristic (ROC) curve. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. 
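The `BinaryLogisticRegressionSummary` notes being converted here all carry the same caveat about instance weights. A sketch of reaching those metrics in the Spark 2.x API (default "label"/"features" columns assumed):

```scala
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.sql.DataFrame

// Every metric below currently ignores any weight column, per the notes in this hunk.
def evaluate(train: DataFrame): Unit = {
  val model = new LogisticRegression().setMaxIter(50).fit(train)
  val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
  println(summary.areaUnderROC)       // scalar AUC
  summary.roc.show(5)                 // DataFrame(FPR, TPR)
  summary.fMeasureByThreshold.show(5) // DataFrame(threshold, F-Measure)
}
```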
*/ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() @@ -1210,8 +1210,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") @@ -1219,8 +1219,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { @@ -1232,8 +1232,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { @@ -1245,8 +1245,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { @@ -1401,18 +1401,18 @@ class BinaryLogisticRegressionSummary private[classification] ( * $$ *
* - * @note In order to avoid unnecessary computation during calculation of the gradient updates - * we lay out the coefficients in column major order during training. This allows us to - * perform feature standardization once, while still retaining sequential memory access - * for speed. We convert back to row major order when we create the model, - * since this form is optimal for the matrix operations used for prediction. - * * @param bcCoefficients The broadcast coefficients corresponding to the features. * @param bcFeaturesStd The broadcast standard deviation values of the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. * @param multinomial Whether to use multinomial (softmax) or binary loss + * + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. */ private class LogisticAggregator( bcCoefficients: Broadcast[Vector], diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a0bd66e731a1..c6035cc4c964 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -268,9 +268,9 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. */ @Since("2.0.0") @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 28cbe1cb01e9..ccfb0ce8f85c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -85,7 +85,8 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H *
*    $$ Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$
* * For the case $E_{max} == E_{min}$, $Rescaled(e_i) = 0.5 * (max + min)$. - * Note that since zero values will probably be transformed to non-zero values, output of the + * + * @note Since zero values will probably be transformed to non-zero values, output of the * transformer will be DenseVector even for sparse input. */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e8e28ba29c84..ea401216aec7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] * because it makes the vector entries sum up to one, and hence linearly dependent. * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * * @see [[StringIndexer]] for converting categorical values into category indices diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 1e49352b8517..6e08bf059124 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -142,8 +142,9 @@ class PCAModel private[ml] ( /** * Transform a vector by computed Principal Components. - * NOTE: Vectors to be transformed must be the same length - * as the source vectors given to [[PCA.fit()]]. + * + * @note Vectors to be transformed must be the same length as the source vectors given + * to [[PCA.fit()]]. */ @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 666070037cdd..0ced21365ff6 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -28,7 +28,10 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** * A feature transformer that filters out stop words from input. - * Note: null values from input array are preserved unless adding null to stopWords explicitly. + * + * @note null values from input array are preserved unless adding null to stopWords + * explicitly. + * * @see [[http://en.wikipedia.org/wiki/Stop_words]] */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 80fe46796f80..8b155f00017c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -113,11 +113,11 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * NOTE: During transformation, if the input column does not exist, + * @param labels Ordered list of labels, corresponding to indices to be assigned. 
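The `OneHotEncoder` note above highlights the dropLast behavior that differs from scikit-learn, and the `StringIndexer` docs in the next hunk describe the indexing step that usually precedes it. A sketch of the two together (the column name "category" is illustrative):

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.DataFrame

// StringIndexer assigns indices by label frequency; OneHotEncoder then expands them.
// The last category is dropped by default; setDropLast(false) keeps all categories,
// matching scikit-learn's behavior.
def encodeCategory(df: DataFrame): DataFrame = {
  val indexed = new StringIndexer()
    .setInputCol("category")
    .setOutputCol("categoryIndex")
    .fit(df)
    .transform(df)
  new OneHotEncoder()
    .setInputCol("categoryIndex")
    .setOutputCol("categoryVec")
    .setDropLast(false)
    .transform(indexed)
}
```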
+ * + * @note During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. - * - * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Since("1.4.0") class StringIndexerModel ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9245931b27ca..96206e0b7ad8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -533,7 +533,7 @@ trait Params extends Identifiable with Serializable { * Returns all params sorted by their names. The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. * - * Note: Developer should not use this method in constructor because we cannot guarantee that + * @note Developer should not use this method in constructor because we cannot guarantee that * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ebc6c12ddcf9..1419da874709 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -207,9 +207,9 @@ class DecisionTreeRegressionModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestRegressor]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. */ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1d2961e0277f..736fd3b9e0f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -879,8 +879,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( * Private copy of model to ensure Params are not modified outside this class. * Coefficients is not a deep copy, but that is acceptable. * - * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set, - * and [[model]] must be set before [[predictions]] is set! + * @note [[predictionCol]] must be set correctly before the value of [[model]] is set, + * and [[model]] must be set before [[predictions]] is set! 
*/ protected val model: GeneralizedLinearRegressionModel = origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 71c542adf6f6..da7ce6b46f2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -103,11 +103,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Whether to standardize the training features before fitting the model. * The coefficients of models will be always returned on the original scale, - * so it will be transparent for users. Note that with/without standardization, - * the models should be always converged to the same solution when no regularization - * is applied. In R's GLMNET package, the default behavior is true as well. + * so it will be transparent for users. * Default is true. * + * @note With/without standardization, the models should be always converged + * to the same solution when no regularization is applied. In R's GLMNET package, + * the default behavior is true as well. + * * @group setParam */ @Since("1.5.0") @@ -624,8 +626,8 @@ class LinearRegressionSummary private[regression] ( * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -634,8 +636,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -644,8 +646,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -654,8 +656,8 @@ class LinearRegressionSummary private[regression] ( * Returns the root mean squared error, which is defined as the square root of * the mean squared error. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. 
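The `LinearRegressionSummary` notes converted here share the same weight-column caveat. A sketch of reading those metrics from a training summary (default "label"/"features" columns assumed, iteration count illustrative):

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// The training summary exposes the metrics documented above; all of them
// currently ignore weightCol, per the notes.
def fitAndReport(train: DataFrame): Unit = {
  val model = new LinearRegression().setMaxIter(100).fit(train)
  val s = model.summary
  println(s"RMSE=${s.rootMeanSquaredError}  MAE=${s.meanAbsoluteError}  R2=${s.r2}")
}
```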
*/ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -664,8 +666,8 @@ class LinearRegressionSummary private[regression] ( * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala index 73d813064dec..e1376927030e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, DataFrameReader} * inconsistent feature dimensions. * - "vectorType": feature vector type, "sparse" (default) or "dense". * - * Note that this class is public for documentation purpose. Please don't use this class directly. + * @note This class is public for documentation purpose. Please don't use this class directly. * Rather, use the data source API as illustrated above. * * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index ede0a060eef9..0a0bc4c00638 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -98,7 +98,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ def computeInitialPredictionAndError( @@ -121,7 +121,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. */ def updatePredictionError( diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index bc4f9e6716ee..e5fa5d53e3fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -221,7 +221,7 @@ trait MLReadable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. * - * Note: Implementing classes should override this to be Java-friendly. + * @note Implementing classes should override this to be Java-friendly. 
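The `MLReadable.load` note above concerns the shortcut around `read.load(path)`. A sketch of the usual round trip (the path and helper name are illustrative):

```scala
import org.apache.spark.ml.regression.LinearRegressionModel

// save/load are the Scala shortcuts for write.save(...) and read.load(...);
// Java callers go through write() and read() explicitly.
def roundTrip(model: LinearRegressionModel, path: String): LinearRegressionModel = {
  model.write.overwrite().save(path)
  LinearRegressionModel.load(path)
}
```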
*/ @Since("1.6.0") def load(path: String): T = read.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d851b983349c..4b650000736e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -202,9 +202,11 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * Train a classification model for Binary Logistic Regression * using Stochastic Gradient Descent. By default L2 regularization is used, * which can be changed via `LogisticRegressionWithSGD.optimizer`. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. + * * Using [[LogisticRegressionWithLBFGS]] is recommended over this. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( @@ -239,7 +241,8 @@ class LogisticRegressionWithSGD private[mllib] ( /** * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("0.8.0") @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") @@ -252,7 +255,6 @@ object LogisticRegressionWithSGD { * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -260,6 +262,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -276,13 +280,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -298,13 +302,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. We use the entire data * set to update the gradient in each iteration. 
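The label-encoding notes moved around in this file all say the same thing: the mllib trainers expect labels already encoded as 0, 1, ..., k - 1 and do not re-index them. A sketch for the binary case (function name illustrative):

```scala
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Labels in `training` must already be 0.0 or 1.0 (or 0 .. k-1 with setNumClasses(k)).
def trainBinary(training: RDD[LabeledPoint]) =
  new LogisticRegressionWithLBFGS()
    .setNumClasses(2)
    .run(training)
```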
- * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -318,11 +322,12 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using a step size of 1.0. We use the entire data set * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -335,8 +340,6 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Multinomial/Binary Logistic Regression using * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. * * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization * penalty to all elements including the intercept. If this is called with one of @@ -344,6 +347,9 @@ object LogisticRegressionWithSGD { * into a call to ml.LogisticRegression, otherwise this will use the existing mllib * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the * intercept. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("1.1.0") class LogisticRegressionWithLBFGS diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 7c3ccbb40b81..aec1526b55c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -125,7 +125,8 @@ object SVMModel extends Loader[SVMModel] { /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. - * NOTE: Labels used in SVM should be {0, 1}. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") class SVMWithSGD private ( @@ -158,7 +159,9 @@ class SVMWithSGD private ( } /** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. + * Top-level methods for calling SVM. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") object SVMWithSGD { @@ -169,8 +172,6 @@ object SVMWithSGD { * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. * - * NOTE: Labels used in SVM should be {0, 1}. - * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. 
@@ -178,6 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") def train( @@ -195,7 +198,8 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} + * + * @note Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -217,13 +221,14 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train( @@ -238,11 +243,12 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using a step size of 1.0. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 43193adf3e18..56cdeea5f7a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -41,14 +41,14 @@ import org.apache.spark.util.Utils * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. - * * @param k Number of independent Gaussians in the mixture model. * @param convergenceTol Maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations Maximum number of iterations allowed. + * + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. 
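A short sketch of the GaussianMixture parameters documented above (k, convergenceTol, maxIterations); `features` is an assumed RDD[Vector] of reasonably low-dimensional points, per the high-dimensionality note:
{{{
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def fitGmm(features: RDD[Vector]) =
  new GaussianMixture()
    .setK(3)                  // number of Gaussians in the mixture
    .setConvergenceTol(0.01)  // max change in log-likelihood to declare convergence
    .setMaxIterations(100)    // cap on EM iterations
    .run(features)
}}}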
*/ @Since("1.3.0") class GaussianMixture private ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ed9c064879d0..fa72b72e2d92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -56,14 +56,18 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Set the number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d999b9be8e8a..7c52abdeaac2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -175,7 +175,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ @Since("1.3.0") @@ -187,7 +187,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * * If set to -1, then topicConcentration is set automatically. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 90d8a558f10d..b5b0e64a2a6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -66,7 +66,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 
*/ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index ae324f86fe6d..7365ea1f200d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -93,9 +93,11 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with - * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * care. * * Default: true + * + * @note Checkpoints will be cleaned up via reference counting, regardless. */ @Since("2.0.0") def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { @@ -348,7 +350,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in * each iteration. * - * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * @note This should be adjusted in synch with [[LDA.setMaxIterations()]] * so the entire corpus is used. Specifically, set both so that * maxIterations * miniBatchFraction >= 1. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index f0779491e637..003d1411a9cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -39,7 +39,7 @@ private[evaluation] object AreaUnderCurve { /** * Returns the area under the given curve. * - * @param curve a RDD of ordered 2D points stored in pairs representing a curve + * @param curve an RDD of ordered 2D points stored in pairs representing a curve */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fbd217af74ec..c94d7890cf55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) @Since("1.0.0") @@ -132,7 +132,9 @@ sealed trait Vector extends Serializable { /** * Number of active entries. An "active entry" is an element which is explicitly stored, - * regardless of its value. Note that inactive entries have value 0. + * regardless of its value. + * + * @note Inactive entries have value 0. 
*/ @Since("1.4.0") def numActives: Int diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 377be6bfb988..03866753b50e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -451,7 +451,7 @@ class BlockMatrix @Since("1.3.0") ( * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause * some performance issues until support for multiplying two sparse matrices is added. * - * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * @note The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added * with each other. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index b03b3ecde94f..809906a15833 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -188,8 +188,9 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be - * computed on matrices with more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ec32e37afb79..4b120332ab8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -106,8 +106,9 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with - * more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -168,9 +169,6 @@ class RowMatrix @Since("1.0.0") ( * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's * eigen-decomposition is set to 1e-10. * - * @note The conditions that decide which method to use internally and the default parameters are - * subject to change. - * * @param k number of leading singular values to keep (0 < k <= n). * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values @@ -180,6 +178,9 @@ class RowMatrix @Since("1.0.0") ( * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. + * + * @note The conditions that decide which method to use internally and the default parameters are + * subject to change. 
*/ @Since("1.0.0") def computeSVD( @@ -319,9 +320,11 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. Note that this cannot - * be computed on matrices with more than 65535 columns. + * Computes the covariance matrix, treating each row as an observation. + * * @return a local dense matrix of size n x n + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeCovariance(): Matrix = { @@ -369,12 +372,12 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * - * Note that this cannot be computed on matrices with more than 65535 columns. - * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components, and * a vector of values which indicate how much variance each principal component * explains + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.6.0") def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 81e64de4e5b5..c49e72646bf1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -305,7 +305,8 @@ class LeastSquaresGradient extends Gradient { * :: DeveloperApi :: * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification. * See also the documentation for the precise formulation. - * NOTE: This assumes that the labels are {0,1} + * + * @note This assumes that the labels are {0,1} */ @DeveloperApi class HingeGradient extends Gradient { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 0f7857b8d862..005119616f06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** - * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index c642573ccba6..24e4dcccc843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -43,14 +43,14 @@ import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. 
* - * Note: If you create the model directly using constructor, please be aware that fast prediction - * requires cached user/product features and their associated partitioners. - * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * + * @note If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. */ @Since("0.8.0") class MatrixFactorizationModel @Since("0.8.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f3159f7e724c..925fdf4d7e7b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -60,15 +60,15 @@ object Statistics { * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column - * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. - * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. */ @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -77,12 +77,12 @@ object Statistics { * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) @@ -98,15 +98,15 @@ object Statistics { * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. 
* Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) @@ -122,15 +122,15 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * - * Note: the two input Vectors need to have the same size. - * `observed` cannot contain negative values. - * `expected` cannot contain nonpositive values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @param expected Vector containing the expected categorical counts/relative frequencies. * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note The two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. */ @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { @@ -141,11 +141,11 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * - * Note: `observed` cannot contain negative values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note `observed` cannot contain negative values. */ @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36feab7859b4..d846c43cf291 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -75,10 +75,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -86,6 +82,10 @@ object DecisionTree extends Serializable with Logging { * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
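Before leaving the Statistics hunks above, a brief usage sketch; `x` and `y` are assumed RDD[Double]s laid out with identical partitioning, as the notes require:
{{{
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Cache the inputs before a Spearman correlation to avoid recomputing the lineage.
x.cache(); y.cache()
val rho = Statistics.corr(x, y, "spearman")

// Goodness-of-fit test: observed must be nonnegative, expected strictly positive.
val observed = Vectors.dense(4.0, 6.0, 5.0)
val expected = Vectors.dense(5.0, 5.0, 5.0)
val result = Statistics.chiSqTest(observed, expected)
}}}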
*/ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -96,10 +96,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -108,6 +104,10 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means * 1 internal node + 2 leaf nodes). * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.0.0") def train( @@ -123,10 +123,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -136,6 +132,10 @@ object DecisionTree extends Serializable with Logging { * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.2.0") def train( @@ -152,10 +152,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -170,6 +166,10 @@ object DecisionTree extends Serializable with Logging { * indicates that feature n is categorical with k categories * indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
*/ @Since("1.0.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index de14ddf024d7..09274a2e1b2a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -42,11 +42,13 @@ trait Loss extends Serializable { /** * Method to calculate error of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. + * * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data + * + * @note This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. */ @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { @@ -55,11 +57,13 @@ trait Loss extends Serializable { /** * Method to calculate loss when the predictions are already known. - * Note: This method is used in the method evaluateEachIteration to avoid recomputing the - * predicted values from previously fit trees. + * * @param prediction Predicted label. * @param label True label. * @return Measure of model error on datapoint. + * + * @note This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. */ private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 657ed0a8ecda..299950785e42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -187,7 +187,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ @Since("1.4.0") @@ -213,7 +213,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. 
*/ @Since("1.4.0") diff --git a/pom.xml b/pom.xml index 650b4cd965b6..024b2850d0a3 100644 --- a/pom.xml +++ b/pom.xml @@ -2476,6 +2476,13 @@ maven-javadoc-plugin -Xdoclint:all -Xdoclint:-missing + + + note + a + Note: + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2d3a95b163a7..92b45657210e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -741,7 +741,8 @@ object Unidoc { javacOptions in (JavaUnidoc, unidoc) := Seq( "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", - "-noqualifier", "java.lang" + "-noqualifier", "java.lang", + "-tag", """note:a:Note\:""" ), // Use GitHub repository for Scaladoc source links diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 3b1c5519bd87..7250eab6705a 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -28,7 +28,7 @@ class KernelDensity(object): """ - Estimate probability density at required points given a RDD of samples + Estimate probability density at required points given an RDD of samples from the population. >>> kd = KernelDensity() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index ed6fd4bca4c5..97755807ef26 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -499,7 +499,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ - Generate a RDD of LabeledPoints. + Generate an RDD of LabeledPoints. """ return callMLlibFunc( "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a163ceafe9d3..641787ee20e0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1218,7 +1218,7 @@ def mergeMaps(m1, m2): def top(self, num, key=None): """ - Get the top N elements from a RDD. + Get the top N elements from an RDD. Note that this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory. @@ -1242,7 +1242,7 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as + Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. Note that this method should only be used if the resulting array is expected diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index bf27d8047a75..134424add3b6 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -144,7 +144,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, """ .. note:: Experimental - Create a RDD from Kafka using offset ranges for each topic and partition. + Create an RDD from Kafka using offset ranges for each topic and partition. :param sc: SparkContext object :param kafkaParams: Additional params for Kafka @@ -155,7 +155,7 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess meta using messageHandler (default is None). 
- :return: A RDD object + :return: An RDD object """ if leaders is None: leaders = dict() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index dc90659a676e..0b95a8821b05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -165,10 +165,10 @@ object Encoders { * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java * serialization. This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -177,10 +177,10 @@ object Encoders { * Creates an encoder that serializes objects of type T using generic Java serialization. * This encoder maps T into a single byte array (binary) field. * - * Note that this is extremely inefficient and should only be used as the last resort. - * * T must be publicly accessible. * + * @note This is extremely inefficient and should only be used as the last resort. + * * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index e121044288e5..21f3497ba06f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.InterfaceStability * The data type representing calendar time intervals. The calendar time interval is stored * internally in two components: number of months the number of microseconds. * - * Note that calendar intervals are not comparable. - * * Please use the singleton [[DataTypes.CalendarIntervalType]]. * + * @note Calendar intervals are not comparable. + * * @since 1.5.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7a131b30eafd..fa3b2b9de5d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -118,7 +118,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * Note that the internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via "expr", but this method is for * debugging purposes only and can change in any future Spark releases. 
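For reference, the pom.xml and SparkBuild.scala hunks above register `note` as a custom doc tag (`-tag note:a:Note:`, where `a` allows the tag in all doc locations), so comments written against this convention render with a "Note:" heading; the method below is purely hypothetical:
{{{
/**
 * Computes something.
 *
 * @note Rendered by javadoc/scaladoc as a "Note:" section thanks to the
 *       custom tag registered in the build.
 */
def computeSomething(): Unit = ()
}}}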
* * @groupname java_expr_ops Java-specific expression operators diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b5bbcee37150..6335fc4579a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -51,7 +51,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * - * Note that NaN values will be removed from the numerical column before calculation * @param col the name of the numerical column * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. @@ -61,6 +60,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities * + * @note NaN values will be removed from the numerical column before calculation + * * @since 2.0.0 */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e0c89811ddbf..15281f24fa62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -218,7 +218,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. * - * Note: Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * @note Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based * resolution. For example: * * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3761773698df..3c75a6a45ec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -377,7 +377,7 @@ class Dataset[T] private[sql]( /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. - * This can be quite convenient in conversion from a RDD of tuples into a [[DataFrame]] with + * This can be quite convenient in conversion from an RDD of tuples into a [[DataFrame]] with * meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... @@ -703,13 +703,13 @@ class Dataset[T] private[sql]( * df1.join(df2, "user_id") * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumn Name of the column to join on. This column must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. 
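The self-join note above is easier to follow with the usual aliasing workaround; `df`, `user_id` and `user_name` are assumed names:
{{{
import org.apache.spark.sql.functions.col

// Alias both sides so columns can still be referenced after the self-join.
val joined = df.as("l").join(df.as("r"), col("l.user_id") === col("r.user_id"))
val projected = joined.select(col("l.user_id"), col("r.user_name"))
}}}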
+ * * @group untypedrel * @since 2.0.0 */ @@ -728,13 +728,13 @@ class Dataset[T] private[sql]( * df1.join(df2, Seq("user_id", "user_name")) * }}} * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -748,14 +748,14 @@ class Dataset[T] private[sql]( * Different from other join functions, the join columns will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * + * @note If you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * * @group untypedrel * @since 2.0.0 */ @@ -856,10 +856,10 @@ class Dataset[T] private[sql]( /** * Explicit cartesian join with another [[DataFrame]]. * - * Note that cartesian joins are very expensive without an extra filter that can be pushed down. - * * @param right Right side of the join operation. * + * @note Cartesian joins are very expensive without an extra filter that can be pushed down. + * * @group untypedrel * @since 2.1.0 */ @@ -1044,7 +1044,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1053,7 +1054,8 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. + * + * @note The column name can also reference to a nested column like `a.b`. * * @group untypedrel * @since 2.0.0 @@ -1621,7 +1623,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1635,7 +1637,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset containing rows in this Dataset but not in another Dataset. 
* This is equivalent to `EXCEPT` in SQL. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel @@ -1648,13 +1650,13 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -1670,12 +1672,12 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * - * Note: this is NOT guaranteed to provide exactly the fraction of the total count - * of the given [[Dataset]]. - * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * + * @note This is NOT guaranteed to provide exactly the fraction of the total count + * of the given [[Dataset]]. + * * @group typedrel * @since 1.6.0 */ @@ -2375,7 +2377,7 @@ class Dataset[T] private[sql]( * * The iterator will consume as much memory as the largest partition in this Dataset. * - * Note: this results in multiple Spark jobs, and if the input Dataset is the result + * @note this results in multiple Spark jobs, and if the input Dataset is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input Dataset should be cached first. * @@ -2453,7 +2455,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `dropDuplicates`. * - * Note that, equality checking is performed directly on the encoded representation of the data + * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. * * @group typedrel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3c5cf037c578..2fae93651b34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -181,9 +181,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -208,6 +205,10 @@ class SQLContext private[sql](val sparkSession: SparkSession) * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. 
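A small sketch of UDF registration under the determinism note above; `sqlContext` and a DataFrame `df` with a string column `name` are assumed:
{{{
import org.apache.spark.sql.functions.udf

// The function must be deterministic: the optimizer may invoke it more or
// fewer times than it appears in the query.
val strLen = udf((s: String) => s.length)
sqlContext.udf.register("strLen", (s: String) => s.length)  // for use from SQL text
val withLen = df.select(strLen(df("name")))
}}}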
+ * * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 58b2ab395717..e09e3caa3c98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -155,9 +155,6 @@ class SparkSession private( /** * A collection of methods for registering user-defined functions (UDF). - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. * * The following example registers a Scala closure as UDF: * {{{ @@ -182,6 +179,10 @@ class SparkSession private( * DataTypes.StringType); * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 2.0.0 */ def udf: UDFRegistration = sessionState.udf @@ -201,7 +202,7 @@ class SparkSession private( * Start a new session with isolated SQL configurations, temporary tables, registered * functions are isolated, but sharing the underlying [[SparkContext]] and cached data. * - * Note: Other than the [[SparkContext]], all shared state is initialized lazily. + * @note Other than the [[SparkContext]], all shared state is initialized lazily. * This method will force the initialization of the shared state to ensure that parent * and child sessions are set up with the same shared state. If the underlying catalog * implementation is Hive, this will initialize the metastore, which may take some time. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 0444ad10d34f..6043c5ee14b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -39,7 +39,8 @@ import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. - * Note that the user-defined functions must be deterministic. + * + * @note The user-defined functions must be deterministic. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 4914a9d722a8..1b56c08f729c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,7 +28,7 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, checkpointLocation: String, @@ -49,7 +49,7 @@ package object state { storeUpdateFunction) } - /** Map each partition of a RDD along with data in a [[StateStore]]. */ + /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 28598af78165..36dd5f78ac13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.types.DataType /** * A user-defined function. To create one, use the `udf` functions in [[functions]]. - * Note that the user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times than - * it is present in the query. + * * As an example: * {{{ * // Defined a UDF that returns true or false based on some numeric score. @@ -37,6 +35,10 @@ import org.apache.spark.sql.types.DataType * df.select( predict(df("score")) ) * }}} * + * @note The user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. + * * @since 1.3.0 */ @InterfaceStability.Stable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e221c032b82f..d5940c638acd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -476,7 +476,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * @note The list of columns should match with grouping columns exactly, or empty (means all the * grouping columns). * * @group agg_funcs @@ -489,7 +489,7 @@ object functions { * * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) * - * Note: the list of columns should match with grouping columns exactly. + * @note The list of columns should match with grouping columns exactly. * * @group agg_funcs * @since 2.0.0 @@ -1120,7 +1120,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1140,7 +1140,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * Note that this is indeterministic when data partitions are not fixed. + * @note This is indeterministic when data partitions are not fixed. * * @group normal_funcs * @since 1.4.0 @@ -1159,7 +1159,7 @@ object functions { /** * Partition ID. * - * Note that this is indeterministic because it depends on data partitioning and task scheduling. + * @note This is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 @@ -2207,7 +2207,7 @@ object functions { * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. * - * NOTE: The position is not zero based, but 1 based index. 
Returns 0 if substr + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2242,7 +2242,8 @@ object functions { /** * Locate the position of the first occurrence of substr. - * NOTE: The position is not zero based, but 1 based index. Returns 0 if substr + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2255,7 +2256,7 @@ object functions { /** * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index. returns 0 if substr + * @note The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs @@ -2369,7 +2370,8 @@ object functions { /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string representation of the regular expression. + * + * @note Pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 @@ -2468,7 +2470,7 @@ object functions { * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All * pattern letters of [[java.text.SimpleDateFormat]] can be used. * - * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use when ever possible specialized functions like [[year]]. These benefit from a * specialized implementation. * * @group datetime_funcs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index dec316be7aea..7c64e28d2472 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -140,7 +140,7 @@ abstract class JdbcDialect extends Serializable { * tried in reverse order. A user-added dialect will thus be applied first, * overwriting the defaults. * - * Note that all new dialects are applied to new jdbc DataFrames only. Make + * @note All new dialects are applied to new jdbc DataFrames only. Make * sure to register your dialects first. */ @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 15a48072525b..ff6dd8cb0cf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -69,7 +69,8 @@ trait DataSourceRegister { trait RelationProvider { /** * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. */ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation @@ -99,7 +100,8 @@ trait RelationProvider { trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced * by the Map that is passed to the function. 
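To make the dialect-registration note above concrete, a hedged sketch of a user-added dialect; the `jdbc:mydb` URL prefix is made up:
{{{
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
}
// Register before creating any jdbc DataFrames, as the note above requires.
JdbcDialects.registerDialect(MyDialect)
}}}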
*/ def createRelation( @@ -205,7 +207,7 @@ abstract class BaseRelation { * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. * - * Note that it is always better to overestimate size than underestimate, because underestimation + * @note It is always better to overestimate size than underestimate, because underestimation * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). * * @since 1.3.0 @@ -219,7 +221,7 @@ abstract class BaseRelation { * * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] * - * Note: The internal representation is not stable across releases and thus data sources outside + * @note The internal representation is not stable across releases and thus data sources outside * of Spark SQL should leave this as true. * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 5e93fc469a41..4504582187b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they can be invoked by + * @note Implementations should guarantee thread-safety as they can be invoked by * multiple different threads. */ @Experimental @@ -39,24 +39,26 @@ trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Note that this can be invoked by multiple different threads. * * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param durationNs the execution time for this query in nanoseconds. + * + * @note This can be invoked by multiple different threads. */ @DeveloperApi def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. * @param exception the exception that failed this query. + * + * @note This can be invoked by multiple different threads. 
*/ @DeveloperApi def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0daa29b666f6..b272c8e7d79c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val allColumns = fields.map(_.name).mkString(",") val schema = StructType(fields) - // Create a RDD for the schema + // Create an RDD for the schema val rdd = sparkContext.parallelize((1 to 10000), 10).map { i => Row( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 4808d0fcbc6c..444261da8de6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -421,11 +421,11 @@ class StreamingContext private[streaming] ( * by "moving" them from another location within the same file system. File names * starting with . are ignored. * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new file * @param recordLength length of each record in bytes + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream( directory: String, @@ -447,12 +447,12 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], @@ -465,14 +465,14 @@ class StreamingContext private[streaming] ( * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD + * + * @note Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. 
*/ def queueStream[T: ClassTag]( queue: Queue[RDD[T]], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index da9ff858853c..aa4003c62e1e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -74,7 +74,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions) - /** Method that generates a RDD for the given Duration */ + /** Method that generates an RDD for the given Duration */ def compute(validTime: Time): JavaPairRDD[K, V] = { dstream.compute(validTime) match { case Some(rdd) => new JavaPairRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4c4376a089f5..b43b9405def9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -218,11 +218,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * for new files and reads them as flat binary files with fixed record lengths, * yielding byte arrays * - * '''Note:''' We ensure that the byte array for each record in the - * resulting RDDs of the DStream has the provided record length. - * * @param directory HDFS directory to monitor for new files * @param recordLength The length at which to split the records + * + * @note We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. */ def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { ssc.binaryRecordsStream(directory, recordLength) @@ -352,13 +352,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @param queue Queue of RDDs + * @tparam T Type of objects in the RDD + * + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. - * - * @param queue Queue of RDDs - * @tparam T Type of objects in the RDD */ def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = { implicit val cm: ClassTag[T] = @@ -372,14 +372,14 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: - * 1. Changes to the queue after the stream is created will not be recognized. - * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of - * those RDDs, so `queueStream` doesn't support checkpointing. - * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD + * + * @note + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. 
Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. */ def queueStream[T]( queue: java.util.Queue[JavaRDD[T]], @@ -396,7 +396,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: + * @note * 1. Changes to the queue after the stream is created will not be recognized. * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. @@ -454,9 +454,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to a + * JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). */ @@ -476,9 +477,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a new DStream in which each RDD is generated by applying a function on RDDs of * the DStreams. The order of the JavaRDDs in the transform function parameter will be the - * same as the order of corresponding DStreams in the list. Note that for adding a - * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using - * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). + * same as the order of corresponding DStreams in the list. + * + * @note For adding a JavaPairDStream in the list of JavaDStreams, convert it to + * a JavaDStream using [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). 
*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 7e0a2ca609c8..e23edfa50651 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -69,13 +69,13 @@ abstract class DStream[T: ClassTag] ( // Methods that should be implemented by subclasses of DStream // ======================================================================= - /** Time interval after which the DStream generates a RDD */ + /** Time interval after which the DStream generates an RDD */ def slideDuration: Duration /** List of parent DStreams on which this DStream depends on */ def dependencies: List[DStream[_]] - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ def compute(validTime: Time): Option[RDD[T]] // ======================================================================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index ed08191f41cc..9512db7d7d75 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -128,7 +128,7 @@ class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: Clas super.initialize(time) } - /** Method that generates a RDD for the given time */ + /** Method that generates an RDD for the given time */ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD val prevStateRDD = getOrCompute(validTime - slideDuration) match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index ce5a6e00fb2f..a37fac87300b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -186,7 +186,7 @@ class WriteAheadLogBackedBlockRDDSuite assert(rdd.collect() === data.flatten) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { From 8b1e1088eb274fb15260cd5d6d9508d42837a4d6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 11:28:25 +0000 Subject: [PATCH 260/381] [SPARK-18353][CORE] spark.rpc.askTimeout default value is not 120s ## What changes were proposed in this pull request? Avoid hard-coding spark.rpc.askTimeout to non-default in Client; fix doc about spark.rpc.askTimeout default ## How was this patch tested? Existing tests Author: Sean Owen Closes #15833 from srowen/SPARK-18353.
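To make the intended precedence concrete, here is a minimal, self-contained sketch of the behaviour the `Client` change below introduces. Only the `spark.rpc.askTimeout` key and the guarded 10s fallback come from this patch; the object name and the `120s` user value are invented for illustration.

```scala
import org.apache.spark.SparkConf

object AskTimeoutFallbackSketch {
  def main(args: Array[String]): Unit = {
    // A value supplied by the user (e.g. --conf spark.rpc.askTimeout=120s) must survive.
    val userConf = new SparkConf(loadDefaults = false).set("spark.rpc.askTimeout", "120s")
    if (!userConf.contains("spark.rpc.askTimeout")) {
      userConf.set("spark.rpc.askTimeout", "10s")
    }
    assert(userConf.get("spark.rpc.askTimeout") == "120s")

    // With no user-supplied value, Client still falls back to its short 10s ask timeout.
    val defaultConf = new SparkConf(loadDefaults = false)
    if (!defaultConf.contains("spark.rpc.askTimeout")) {
      defaultConf.set("spark.rpc.askTimeout", "10s")
    }
    assert(defaultConf.get("spark.rpc.askTimeout") == "10s")
  }
}
```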
--- core/src/main/scala/org/apache/spark/deploy/Client.scala | 4 +++- docs/configuration.md | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b7113..a4de3d7eaf45 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/docs/configuration.md b/docs/configuration.md index c021a377ba10..a3b4ff01e6d9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1184,7 +1184,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.askTimeout - 120s + spark.network.timeout Duration for an RPC ask operation to wait before timing out. @@ -1566,7 +1566,7 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.ack.wait.timeout - 60s + spark.network.timeout How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, From ded5fefb6f5c0a97bf3d7fa1c0494dc434b6ee40 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 19 Nov 2016 13:48:56 +0000 Subject: [PATCH 261/381] [SPARK-18448][CORE] Fix @since 2.1.0 on new SparkSession.close() method ## What changes were proposed in this pull request? Fix since 2.1.0 on new SparkSession.close() method. I goofed in https://github.com/apache/spark/pull/15932 because it was back-ported to 2.1 instead of just master as originally planned. Author: Sean Owen Closes #15938 from srowen/SPARK-18448.2. --- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e09e3caa3c98..71b1880dc071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -652,7 +652,7 @@ class SparkSession private( /** * Synonym for `stop()`. * - * @since 2.2.0 + * @since 2.1.0 */ override def close(): Unit = stop() From ea77c81ec0db27ea4709f71dc080d00167505a7d Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Sat, 19 Nov 2016 16:02:59 -0800 Subject: [PATCH 262/381] [SPARK-17062][MESOS] add conf option to mesos dispatcher Adds --conf option to set spark configuration properties in mesos dispatcher. Properties provided with --conf take precedence over properties within the properties file. The reason for this PR is that for simple configuration or testing purposes we need to provide a property file (ideally a shared one for a cluster) even if we just provide a single property. Manually tested. Author: Stavros Kontopoulos Author: Stavros Kontopoulos Closes #14650 from skonto/dipatcher_conf.
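As a quick illustration of the precedence rule described above, the following sketch roughly mirrors what the dispatcher now does: properties-file entries are applied only if missing, then `--conf` pairs are filtered to `spark.*` keys and applied last, so they win. All map keys and values other than `spark.mesos.key2`/`key1` (taken from the new test below) are invented for the example.

```scala
import org.apache.spark.SparkConf

object DispatcherConfPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)

    // 1. Entries loaded from the properties file (e.g. conf/spark-defaults.conf).
    val fromPropertiesFile = Map(
      "spark.mesos.key2" -> "fileValue",
      "spark.cores.max" -> "4")
    fromPropertiesFile.foreach { case (k, v) => conf.setIfMissing(k, v) }

    // 2. Pairs given with --conf, applied afterwards; keys without the "spark." prefix are dropped.
    val fromConfFlags = Map(
      "spark.mesos.key2" -> "cliValue",
      "key1" -> "value1")
    fromConfFlags
      .filter { case (k, _) => k.startsWith("spark.") }
      .foreach { case (k, v) => conf.set(k, v) }

    assert(conf.get("spark.mesos.key2") == "cliValue") // --conf overrides the properties file
    assert(conf.getOption("key1").isEmpty)             // non-spark.* key is ignored
    assert(conf.get("spark.cores.max") == "4")         // file-only setting is kept
  }
}
```

Applying the command-line pairs after the file keeps the ordering consistent with how `spark-submit` treats `--conf` versus `spark-defaults.conf`.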
--- .../org/apache/spark/deploy/SparkSubmit.scala | 18 ++--- .../spark/deploy/SparkSubmitArguments.scala | 6 +- .../apache/spark/util/CommandLineUtils.scala | 56 +++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 14 ++++ .../spark/deploy/SparkSubmitSuite.scala | 43 +++++++----- .../deploy/mesos/MesosClusterDispatcher.scala | 9 ++- .../MesosClusterDispatcherArguments.scala | 70 +++++++++++++++---- ...MesosClusterDispatcherArgumentsSuite.scala | 63 +++++++++++++++++ .../mesos/MesosClusterDispatcherSuite.scala | 40 +++++++++++ 9 files changed, 266 insertions(+), 53 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala create mode 100644 mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala create mode 100644 mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index c70061bc5b5b..85f80b6971e8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -41,12 +41,11 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException} -import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL} +import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Whether to submit, kill, or request the status of an application. @@ -63,7 +62,7 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. 
*/ -object SparkSubmit { +object SparkSubmit extends CommandLineUtils { // Cluster managers private val YARN = 1 @@ -87,15 +86,6 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // scalastyle:off println - // Exposed for testing - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) - private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn(1) - } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to ____ __ @@ -115,7 +105,7 @@ object SparkSubmit { } // scalastyle:on println - def main(args: Array[String]): Unit = { + override def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { // scalastyle:off println diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec9..b1d36e1821cc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -412,10 +412,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - value.split("=", 2).toSeq match { - case Seq(k, v) => sparkProperties(k) = v - case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") - } + val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + sparkProperties(confName) = confValue case PROXY_USER => proxyUser = value diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala new file mode 100644 index 000000000000..d73901686b70 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +import org.apache.spark.SparkException + +/** + * Contains basic command line parsing functionality and methods to parse some common Spark CLI + * options. 
+ */ +private[spark] trait CommandLineUtils { + + // Exposed for testing + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + + private[spark] var printStream: PrintStream = System.err + + // scalastyle:off println + + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + + private[spark] def printErrorAndExit(str: String): Unit = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn(1) + } + + // scalastyle:on println + + private[spark] def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => printErrorAndExit(s"Spark config without '=': $pair") + throw new SparkException(s"Spark config without '=': $pair") + } + } + + def main(args: Array[String]): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 23b95b9f649f..748d729554fc 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2056,6 +2056,20 @@ private[spark] object Utils extends Logging { path } + /** + * Updates Spark config with properties from a set of Properties. + * Provided properties have the highest priority. + */ + def updateSparkConfigFromProperties( + conf: SparkConf, + properties: Map[String, String]) : Unit = { + properties.filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.set(k, v) + } + } + /** Load properties present in the given file. */ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c649e305a37..626888022903 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -34,21 +34,11 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} -// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that needed to be cleared after tests. -class SparkSubmitSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { - override def beforeEach() { - super.beforeEach() - System.setProperty("spark.testing", "true") - } +trait TestPrematureExit { + suite: SparkFunSuite => private val noOpOutputStream = new OutputStream { def write(b: Int) = {} @@ -65,16 +55,19 @@ class SparkSubmitSuite } /** Returns true if the script exits and the given search string is printed. 
*/ - private def testPrematureExit(input: Array[String], searchString: String) = { + private[spark] def testPrematureExit( + input: Array[String], + searchString: String, + mainObject: CommandLineUtils = SparkSubmit) : Unit = { val printStream = new BufferPrintStream() - SparkSubmit.printStream = printStream + mainObject.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = (_) => exitedCleanly = true + mainObject.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { - SparkSubmit.main(input) + mainObject.main(input) } catch { // If exceptions occur after the "exit" has happened, fine to ignore them. // These represent code paths not reachable during normal execution. @@ -88,6 +81,22 @@ class SparkSubmitSuite fail(s"Search string '$searchString' not found in $joined") } } +} + +// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch +// of properties that needed to be cleared after tests. +class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts + with TestPrematureExit { + + override def beforeEach() { + super.beforeEach() + System.setProperty("spark.testing", "true") + } // scalastyle:off println test("prints usage on empty input") { diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 7d6693b4cdf5..792ade8f0bdb 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -92,8 +92,11 @@ private[mesos] class MesosClusterDispatcher( } } -private[mesos] object MesosClusterDispatcher extends Logging { - def main(args: Array[String]) { +private[mesos] object MesosClusterDispatcher + extends Logging + with CommandLineUtils { + + override def main(args: Array[String]) { Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 11e13441eeba..ef08502ec8dd 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -18,23 +18,43 @@ package org.apache.spark.deploy.mesos import scala.annotation.tailrec +import scala.collection.mutable -import org.apache.spark.SparkConf import org.apache.spark.util.{IntParam, Utils} - +import org.apache.spark.SparkConf private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { - var host = Utils.localHostName() - var port = 7077 - var name = "Spark Cluster" - var webUiPort = 8081 + var host: String = Utils.localHostName() + var port: Int = 7077 + var name: String = "Spark 
Cluster" + var webUiPort: Int = 8081 + var verbose: Boolean = false var masterUrl: String = _ var zookeeperUrl: Option[String] = None var propertiesFile: String = _ + val confProperties: mutable.HashMap[String, String] = + new mutable.HashMap[String, String]() parse(args.toList) + // scalastyle:on println propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + Utils.updateSparkConfigFromProperties(conf, confProperties) + + // scalastyle:off println + if (verbose) { + MesosClusterDispatcher.printStream.println(s"Using host: $host") + MesosClusterDispatcher.printStream.println(s"Using port: $port") + MesosClusterDispatcher.printStream.println(s"Using webUiPort: $webUiPort") + MesosClusterDispatcher.printStream.println(s"Framework Name: $name") + + Option(propertiesFile).foreach { file => + MesosClusterDispatcher.printStream.println(s"Using properties file: $file") + } + + MesosClusterDispatcher.printStream.println(s"Spark Config properties set:") + conf.getAll.foreach(println) + } @tailrec private def parse(args: List[String]): Unit = args match { @@ -58,9 +78,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { // scalastyle:off println - System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + MesosClusterDispatcher.printStream + .println("Cluster dispatcher only supports mesos (uri begins with mesos://)") // scalastyle:on println - System.exit(1) + MesosClusterDispatcher.exitFn(1) } masterUrl = value.stripPrefix("mesos://") parse(tail) @@ -73,28 +94,45 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: propertiesFile = value parse(tail) + case ("--conf") :: value :: tail => + val pair = MesosClusterDispatcher. 
+ parseSparkConfProperty(value) + confProperties(pair._1) = pair._2 + parse(tail) + case ("--help") :: tail => - printUsageAndExit(0) + printUsageAndExit(0) + + case ("--verbose") :: tail => + verbose = true + parse(tail) case Nil => - if (masterUrl == null) { + if (Option(masterUrl).isEmpty) { // scalastyle:off println - System.err.println("--master is required") + MesosClusterDispatcher.printStream.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - case _ => + case value => + // scalastyle:off println + MesosClusterDispatcher.printStream.println(s"Unrecognized option: '${value.head}'") + // scalastyle:on println printUsageAndExit(1) } private def printUsageAndExit(exitCode: Int): Unit = { + val outStream = MesosClusterDispatcher.printStream + // scalastyle:off println - System.err.println( + outStream.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + "Options:\n" + " -h HOST, --host HOST Hostname to listen on\n" + + " --help Show this help message and exit.\n" + + " --verbose, Print additional debug output.\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + " --name NAME Framework name to show in Mesos UI\n" + @@ -102,8 +140,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + - " Default is conf/spark-defaults.conf.") + " Default is conf/spark-defaults.conf \n" + + " --conf PROP=VALUE Arbitrary Spark configuration property.\n" + + " Takes precedence over defined properties in properties-file.") // scalastyle:on println - System.exit(exitCode) + MesosClusterDispatcher.exitFn(exitCode) } } diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala new file mode 100644 index 000000000000..b6c0b325361d --- /dev/null +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite + with TestPrematureExit { + + test("test if spark config args are passed sucessfully") { + val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", + "--conf", "spark.mesos.key2=value2", "--verbose") + val conf = new SparkConf() + new MesosClusterDispatcherArguments(args, conf) + + assert(conf.getOption("key1").isEmpty) + assert(conf.get("spark.mesos.key2") == "value2") + } + + test("test non conf settings") { + val masterUrl = "mesos://localhost:5050" + val port = "1212" + val zookeeperUrl = "zk://localhost:2181" + val host = "localhost" + val webUiPort = "2323" + val name = "myFramework" + + val args1 = Array("--master", masterUrl, "--verbose", "--name", name) + val args2 = Array("-p", port, "-h", host, "-z", zookeeperUrl) + val args3 = Array("--webui-port", webUiPort) + + val args = args1 ++ args2 ++ args3 + val conf = new SparkConf() + val mesosDispClusterArgs = new MesosClusterDispatcherArguments(args, conf) + + assert(mesosDispClusterArgs.verbose) + assert(mesosDispClusterArgs.confProperties.isEmpty) + assert(mesosDispClusterArgs.host == host) + assert(Option(mesosDispClusterArgs.masterUrl).isDefined) + assert(mesosDispClusterArgs.masterUrl == masterUrl.stripPrefix("mesos://")) + assert(Option(mesosDispClusterArgs.zookeeperUrl).isDefined) + assert(mesosDispClusterArgs.zookeeperUrl contains zookeeperUrl) + assert(mesosDispClusterArgs.name == name) + assert(mesosDispClusterArgs.webUiPort == webUiPort.toInt) + assert(mesosDispClusterArgs.port == port.toInt) + } +} diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala new file mode 100644 index 000000000000..7484e3b83670 --- /dev/null +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherSuite extends SparkFunSuite + with TestPrematureExit{ + + test("prints usage on empty input") { + testPrematureExit(Array[String](), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints usage with only --help") { + testPrematureExit(Array("--help"), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints error with unrecognized options") { + testPrematureExit(Array("--blarg"), "Unrecognized option: '--blarg'", MesosClusterDispatcher) + testPrematureExit(Array("-bleg"), "Unrecognized option: '-bleg'", MesosClusterDispatcher) + } +} From 856e0042007c789dda4539fb19a5d4580999fbf4 Mon Sep 17 00:00:00 2001 From: sethah Date: Sun, 20 Nov 2016 01:42:37 +0000 Subject: [PATCH 263/381] [SPARK-18456][ML][FOLLOWUP] Use matrix abstraction for coefficients in LogisticRegression training ## What changes were proposed in this pull request? This is a follow up to some of the discussion [here](https://github.com/apache/spark/pull/15593). During LogisticRegression training, we store the coefficients combined with intercepts as a flat vector, but a more natural abstraction is a matrix. Here, we refactor the code to use matrix where possible, which makes the code more readable and greatly simplifies the indexing. Note: We do not use a Breeze matrix for the cost function as was mentioned in the linked PR. This is because LBFGS/OWLQN require an implicit `MutableInnerProductModule[DenseMatrix[Double], Double]` which is not natively defined in Breeze. We would need to extend Breeze in Spark to define it ourselves. Also, we do not modify the `regParamL1Fun` because OWLQN in Breeze requires a `MutableEnumeratedCoordinateField[(Int, Int), DenseVector[Double]]` (since we still use a dense vector for coefficients). Here again we would have to extend Breeze inside Spark. ## How was this patch tested? This is internal code refactoring - the current unit tests passing show us that the change did not break anything. No added functionality in this patch. Author: sethah Closes #15893 from sethah/logreg_refactor. --- .../classification/LogisticRegression.scala | 115 ++++++++---------- 1 file changed, 53 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71a7fe53c15f..f58efd36a1c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -463,16 +463,11 @@ class LogisticRegression @Since("1.2.0") ( } /* - The coefficients are laid out in column major order during training. e.g. for - `numClasses = 3` and `numFeatures = 2` and `fitIntercept = true` the layout is: - - Array(beta_11, beta_21, beta_31, beta_12, beta_22, beta_32, intercept_1, intercept_2, - intercept_3) - - where beta_jk corresponds to the coefficient for class `j` and feature `k`. + The coefficients are laid out in column major order during training. Here we initialize + a column major matrix of initial coefficients. 
*/ - val initialCoefficientsWithIntercept = - Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) + val initialCoefWithInterceptMatrix = + Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept) val initialModelIsValid = optInitialModel match { case Some(_initialModel) => @@ -491,18 +486,15 @@ class LogisticRegression @Since("1.2.0") ( } if (initialModelIsValid) { - val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix - providedCoef.foreachActive { (row, col, value) => - // convert matrix to column major for training - val flatIndex = col * numCoefficientSets + row + providedCoef.foreachActive { (classIndex, featureIndex, value) => // We need to scale the coefficients since they will be trained in the scaled space - initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) + initialCoefWithInterceptMatrix.update(classIndex, featureIndex, + value * featuresStd(featureIndex)) } if ($(fitIntercept)) { - optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = numCoefficientSets * numFeatures + index - initialCoefWithInterceptArray(coefIndex) = value + optInitialModel.get.interceptVector.foreachActive { (classIndex, value) => + initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value) } } } else if ($(fitIntercept) && isMultinomial) { @@ -532,8 +524,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(numClasses * numFeatures + i) = - rawIntercepts(i) - rawMean + initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean) } } else if ($(fitIntercept)) { /* @@ -549,12 +540,12 @@ class LogisticRegression @Since("1.2.0") ( b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( - histogram(1) / histogram(0)) + initialCoefWithInterceptMatrix.update(0, numFeatures, + math.log(histogram(1) / histogram(0))) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.asBreeze.toDenseVector) + new BDV[Double](initialCoefWithInterceptMatrix.toArray)) /* Note that in Logistic Regression, the objective history (loss + regularization) @@ -586,15 +577,24 @@ class LogisticRegression @Since("1.2.0") ( Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. 
*/ - val rawCoefficients = state.x.toArray.clone() - val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - val colMajorIndex = (i % numFeatures) * numCoefficientSets + i / numFeatures - val featureIndex = i % numFeatures - if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(colMajorIndex) / featuresStd(featureIndex) - } else { - 0.0 + val allCoefficients = state.x.toArray.clone() + val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, + allCoefficients) + val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures, + new Array[Double](numCoefficientSets * numFeatures), isTransposed = true) + val interceptVec = if ($(fitIntercept) || !isMultinomial) { + Vectors.zeros(numCoefficientSets) + } else { + Vectors.sparse(numCoefficientSets, Seq()) + } + // separate intercepts and coefficients from the combined matrix + allCoefMatrix.foreachActive { (classIndex, featureIndex, value) => + val isIntercept = $(fitIntercept) && (featureIndex == numFeatures) + if (!isIntercept && featuresStd(featureIndex) != 0.0) { + denseCoefficientMatrix.update(classIndex, featureIndex, + value / featuresStd(featureIndex)) } + if (isIntercept) interceptVec.toArray(classIndex) = value } if ($(regParam) == 0.0 && isMultinomial) { @@ -607,17 +607,16 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val coefficientMean = coefficientArray.sum / coefficientArray.length - coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean} + val denseValues = denseCoefficientMatrix.values + val coefficientMean = denseValues.sum / denseValues.length + denseCoefficientMatrix.update(_ - coefficientMean) } - val denseCoefficientMatrix = - new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true) // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 val compressedCoefficientMatrix = if (isMultinomial) { denseCoefficientMatrix } else { - val compressedVector = Vectors.dense(coefficientArray).compressed + val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed compressedVector match { case dv: DenseVector => denseCoefficientMatrix case sv: SparseVector => @@ -626,25 +625,13 @@ class LogisticRegression @Since("1.2.0") ( } } - val interceptsArray: Array[Double] = if ($(fitIntercept)) { - Array.tabulate(numCoefficientSets) { i => - val coefIndex = numFeatures * numCoefficientSets + i - rawCoefficients(coefIndex) - } - } else { - Array.empty[Double] - } - val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) { - // The intercepts are never regularized, so we always center the mean. 
- val interceptMean = interceptsArray.sum / numClasses - interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean } - Vectors.dense(interceptsArray) - } else if (interceptsArray.length == 1) { - Vectors.dense(interceptsArray) - } else { - Vectors.sparse(numCoefficientSets, Seq()) + // center the intercepts when using multinomial algorithm + if ($(fitIntercept) && isMultinomial) { + val interceptArray = interceptVec.toArray + val interceptMean = interceptArray.sum / interceptArray.length + (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result()) + (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) } } @@ -1424,6 +1411,7 @@ private class LogisticAggregator( private val numFeatures = bcFeaturesStd.value.length private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures private val coefficientSize = bcCoefficients.value.size + private val numCoefficientSets = if (multinomial) numClasses else 1 if (multinomial) { require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " + s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize") @@ -1633,12 +1621,12 @@ private class LogisticAggregator( lossSum / weightSum } - def gradient: Vector = { + def gradient: Matrix = { require(weightSum > 0.0, s"The effective number of instances should be " + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) - result + new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray) } } @@ -1664,6 +1652,7 @@ private class LogisticCostFun( val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length val numCoefficientSets = if (multinomial) numClasses else 1 + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1675,24 +1664,25 @@ private class LogisticCostFun( )(seqOp, combOp, aggregationDepth) } - val totalGradientArray = logisticAggregator.gradient.toArray + val totalGradientMatrix = logisticAggregator.gradient + val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray) // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { var sum = 0.0 - coeffs.foreachActive { case (index, value) => + coefMatrix.foreachActive { case (classIndex, featureIndex, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && index >= numCoefficientSets * numFeatures + val isIntercept = fitIntercept && (featureIndex == numFeatures) if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. 
sum += { if (standardization) { - totalGradientArray(index) += regParamL2 * value + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value) value * value } else { - val featureIndex = index / numCoefficientSets if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to @@ -1700,7 +1690,8 @@ private class LogisticCostFun( // differently to get effectively the same objective function when // the training dataset is not standardized. val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex)) - totalGradientArray(index) += regParamL2 * temp + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp) value * temp } else { 0.0 @@ -1713,6 +1704,6 @@ private class LogisticCostFun( } bcCoeffs.destroy(blocking = false) - (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) + (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray)) } } From d93b6552473468df297a08c0bef9ea0bf0f5c13a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 19 Nov 2016 21:50:20 -0800 Subject: [PATCH 264/381] [SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java ## What changes were proposed in this pull request? This PR avoids that a result of an expression is negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0). This PR casts each operand to `long` before executing a calculation. Since the result is interpreted as long, the result of the expression is positive. ## How was this patch tested? Manually executed query82 of TPC-DS with 100TB Author: Kazuaki Ishizaki Closes #15907 from kiszk/SPARK-18458. --- .../collection/unsafe/sort/RadixSort.java | 48 ++++++++++--------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../unsafe/sort/RadixSortSuite.scala | 28 +++++------ 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55..3dd318471008 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafa..252a35ec6bdf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 366ffda7788d..d5956ea32096 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator} import scala.util.Random +import com.google.common.primitives.Ints + import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom class RadixSortSuite extends SparkFunSuite with Logging { - private val N = 10000 // scale this down for more readable results + private val N = 10000L // scale this down for more readable results /** * Describes a type of sort to test, e.g. two's complement descending. 
Each sort type has @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging { }, 2, 4, false, false, true)) - private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { - val ref = Array.tabulate[Long](size) { i => rand } - val extended = ref ++ Array.fill[Long](size)(0) + private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { - val ref = Array.tabulate[Long](size * 2) { i => rand } - val extended = ref ++ Array.fill[Long](size * 2)(0) + private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) (new LongArray(MemoryBlock.fromLongArray(ref)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { var i = 0 - val out = new Array[Long](length) + val out = new Array[Long](Ints.checkedCast(length)) while (i < length) { out(i) = array.get(offset + i) i += 1 @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging { } } - private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( - buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, - r2: RecordPointerAndKeyPrefix): Int = { - refCmp.compare(r1.keyPrefix, r2.keyPrefix) - } + r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix) }) } From bce9a03677f931d52491e7768aba9e4a19a7e696 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Nov 2016 21:57:09 -0800 Subject: [PATCH 265/381] [SPARK-18508][SQL] Fix documentation error for DateDiff ## What changes were proposed in this pull request? The previous documentation and example for DateDiff was wrong. ## How was this patch tested? Doc only change. Author: Reynold Xin Closes #15937 from rxin/datediff-doc. --- .../sql/catalyst/expressions/datetimeExpressions.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 9cec6be841de..1db1d1995d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1101,11 +1101,14 @@ case class TruncDate(date: Expression, format: Expression) * Returns the number of days from startDate to endDate. 
*/ @ExpressionDescription( - usage = "_FUNC_(date1, date2) - Returns the number of days between `date1` and `date2`.", + usage = "_FUNC_(endDate, startDate) - Returns the number of days from `startDate` to `endDate`.", extended = """ Examples: - > SELECT _FUNC_('2009-07-30', '2009-07-31'); + > SELECT _FUNC_('2009-07-31', '2009-07-30'); 1 + + > SELECT _FUNC_('2009-07-30', '2009-07-31'); + -1 """) case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { From a64f25d8b403b17ff68c9575f6f35b22e5b62427 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Nov 2016 21:57:49 -0800 Subject: [PATCH 266/381] [SQL] Fix documentation for Concat and ConcatWs --- .../sql/catalyst/expressions/stringExpressions.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e74ef9a08750..908aa44f81c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -40,15 +40,13 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of `str1`, `str2`, ..., `strN`.", + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", extended = """ Examples: - > SELECT _FUNC_('Spark','SQL'); + > SELECT _FUNC_('Spark', 'SQL'); SparkSQL """) -// scalastyle:on line.size.limit case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -89,8 +87,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas usage = "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by `sep`.", extended = """ Examples: - > SELECT _FUNC_(' ', Spark', 'SQL'); - Spark SQL + > SELECT _FUNC_(' ', 'Spark', 'SQL'); + Spark SQL """) // scalastyle:on line.size.limit case class ConcatWs(children: Seq[Expression]) From 7ca7a635242377634c302b7816ce60bd9c908527 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 19 Nov 2016 23:55:09 -0800 Subject: [PATCH 267/381] [SPARK-15214][SQL] Code-generation for Generate ## What changes were proposed in this pull request? This PR adds code generation to `Generate`. It supports two code paths: - General `TraversableOnce` based iteration. This used for regular `Generator` (code generation supporting) expressions. This code path expects the expression to return a `TraversableOnce[InternalRow]` and it will iterate over the returned collection. This PR adds code generation for the `stack` generator. - Specialized `ArrayData/MapData` based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns. 
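For reference, both paths are exercised by ordinary generator calls; a small usage sketch (not part of this patch, assuming an active `SparkSession` named `spark`):

```
// Usage sketch only; assumes a SparkSession named `spark`.
import spark.implicits._

val df = Seq((1, Seq(10, 20)), (2, Seq(30, 40))).toDF("key", "values")

// Specialized ArrayData-based iteration: explode / posexplode over an array column.
df.selectExpr("key", "explode(values) value").show()
df.selectExpr("key", "posexplode(values) as (pos, value)").show()

// General TraversableOnce-based iteration: stack (code generated when it produces 50 rows or less).
spark.range(3).selectExpr("id", "stack(2, id, id + 1)").show()
```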
### Benchmarks I have added some benchmarks and it seems we can create a nice speedup for explode: #### Environment ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 Intel(R) Core(TM) i7-4980HQ CPU 2.80GHz ``` #### Explode Array ##### Before ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7377 / 7607 2.3 439.7 1.0X generate explode array wholestage on 6055 / 6086 2.8 360.9 1.2X ``` ##### After ``` generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode array wholestage off 7432 / 7696 2.3 443.0 1.0X generate explode array wholestage on 631 / 646 26.6 37.6 11.8X ``` #### Explode Map ##### Before ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 12792 / 12848 1.3 762.5 1.0X generate explode map wholestage on 11181 / 11237 1.5 666.5 1.1X ``` ##### After ``` generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate explode map wholestage off 10949 / 10972 1.5 652.6 1.0X generate explode map wholestage on 870 / 913 19.3 51.9 12.6X ``` #### Posexplode ##### Before ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7547 / 7580 2.2 449.8 1.0X generate posexplode array wholestage on 5786 / 5838 2.9 344.9 1.3X ``` ##### After ``` generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate posexplode array wholestage off 7535 / 7548 2.2 449.1 1.0X generate posexplode array wholestage on 620 / 624 27.1 37.0 12.1X ``` #### Inline ##### Before ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6935 / 6978 2.4 413.3 1.0X generate inline array wholestage on 6360 / 6400 2.6 379.1 1.1X ``` ##### After ``` generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate inline array wholestage off 6940 / 6966 2.4 413.6 1.0X generate inline array wholestage on 1002 / 1012 16.7 59.7 6.9X ``` #### Stack ##### Before ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12980 / 13104 1.3 773.7 1.0X generate stack wholestage on 11566 / 11580 1.5 689.4 1.1X ``` ##### After ``` generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ generate stack wholestage off 12875 / 12949 1.3 767.4 1.0X generate stack wholestage on 840 / 845 20.0 50.0 15.3X ``` ## How was this patch tested? Existing tests. 
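As a quick manual check on top of the existing suites, one can also confirm from the physical plan that `GenerateExec` ends up inside a `WholeStageCodegenExec` block (sketch only, assuming an active `SparkSession` named `spark`):

```
// Sketch only: verify that Generate participates in whole-stage code generation.
import org.apache.spark.sql.execution.{GenerateExec, WholeStageCodegenExec}

val ds = spark.range(2).selectExpr("id", "explode(array(id + 1, id + 2)) as value")
ds.explain()  // Generate should be printed inside a WholeStageCodegen block

val generateIsCodegened = ds.queryExecution.executedPlan.find {
  case w: WholeStageCodegenExec => w.find(_.isInstanceOf[GenerateExec]).isDefined
  case _ => false
}.isDefined
assert(generateIsCodegened)
```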
Author: Herman van Hovell Author: Herman van Hovell Closes #13065 from hvanhovell/SPARK-15214. --- .../sql/catalyst/expressions/generators.scala | 110 ++++++++-- .../SubexpressionEliminationSuite.scala | 16 +- .../spark/sql/execution/GenerateExec.scala | 202 +++++++++++++++++- .../spark/sql/GeneratorFunctionSuite.scala | 34 +++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 - .../execution/WholeStageCodegenSuite.scala | 32 ++- .../execution/benchmark/MiscBenchmark.scala | 99 ++++++++- 7 files changed, 463 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d042bfb63d56..6c38f4998e91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -60,6 +62,26 @@ trait Generator extends Expression { * rows can be made here. */ def terminate(): TraversableOnce[InternalRow] = Nil + + /** + * Check if this generator supports code generation. + */ + def supportCodegen: Boolean = !isInstanceOf[CodegenFallback] +} + +/** + * A collection producing [[Generator]]. This trait provides a different path for code generation, + * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object. + */ +trait CollectionGenerator extends Generator { + /** The position of an element within the collection should also be returned. */ + def position: Boolean + + /** Rows will be inlined during generation. */ + def inline: Boolean + + /** The type of the returned collection object. */ + def collectionType: DataType = dataType } /** @@ -77,7 +99,9 @@ case class UserDefinedGenerator( private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + val inputSchema = StructType(children.map { e => + StructField(e.simpleString, e.dataType, nullable = true) + }) CatalystTypeConverters.createToScalaConverter(inputSchema) }.asInstanceOf[InternalRow => Row] } @@ -109,8 +133,7 @@ case class UserDefinedGenerator( 1 2 3 NULL """) -case class Stack(children: Seq[Expression]) - extends Expression with Generator with CodegenFallback { +case class Stack(children: Seq[Expression]) extends Generator { private lazy val numRows = children.head.eval().asInstanceOf[Int] private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt @@ -149,21 +172,50 @@ case class Stack(children: Seq[Expression]) InternalRow(fields: _*) } } + + + /** + * Only support code generation when stack produces 50 rows or less. + */ + override def supportCodegen: Boolean = numRows <= 50 + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Rows - we write these into an array. 
+ val rowData = ctx.freshName("rows") + ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + val values = children.tail + val dataTypes = values.take(numFields).map(_.dataType) + val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => + val fields = Seq.tabulate(numFields) { col => + val index = row * numFields + col + if (index < values.length) values(index) else Literal(null, dataTypes(col)) + } + val eval = CreateStruct(fields).genCode(ctx) + s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + }) + + // Create the collection. + val wrapperClass = classOf[mutable.WrappedArray[_]].getName + ctx.addMutableState( + s"$wrapperClass", + ev.value, + s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + ev.copy(code = code, isNull = "false") + } } /** - * A base class for Explode and PosExplode + * A base class for [[Explode]] and [[PosExplode]]. */ -abstract class ExplodeBase(child: Expression, position: Boolean) - extends UnaryExpression with Generator with CodegenFallback with Serializable { +abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable { + override val inline: Boolean = false - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: ArrayType | _: MapType => TypeCheckResult.TypeCheckSuccess - } else { + case _ => TypeCheckResult.TypeCheckFailure( s"input to function explode should be array or map type, not ${child.dataType}") - } } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -171,7 +223,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case ArrayType(et, containsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) + .add("pos", IntegerType, nullable = false) .add("col", et, containsNull) } else { new StructType() @@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) case MapType(kt, vt, valueContainsNull) => if (position) { new StructType() - .add("pos", IntegerType, false) - .add("key", kt, false) + .add("pos", IntegerType, nullable = false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } else { new StructType() - .add("key", kt, false) + .add("key", kt, nullable = false) .add("value", vt, valueContainsNull) } } @@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) } } } + + override def collectionType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } /** @@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean) 20 """) // scalastyle:on line.size.limit -case class Explode(child: Expression) extends ExplodeBase(child, position = false) +case class Explode(child: Expression) extends ExplodeBase { + override val position: Boolean = false +} /** * Given an input array produces a sequence of rows for each position and value in the array. 
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals 1 20 """) // scalastyle:on line.size.limit -case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) +case class PosExplode(child: Expression) extends ExplodeBase { + override val position = true +} /** * Explodes an array of structs into a table. @@ -273,10 +335,12 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t 1 a 2 b """) -case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { +case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator { + override val inline: Boolean = true + override val position: Boolean = false override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(et, _) if et.isInstanceOf[StructType] => + case ArrayType(st: StructType, _) => TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( @@ -284,9 +348,11 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with } override def elementSchema: StructType = child.dataType match { - case ArrayType(et : StructType, _) => et + case ArrayType(st: StructType, _) => st } + override def collectionType: DataType = child.dataType + private lazy val numFields = elementSchema.fields.length override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with yield inputArray.getStruct(i, numFields) } } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 1e39b24fe877..2db2a043e546 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{DataType, IntegerType} class SubexpressionEliminationSuite extends SparkFunSuite { test("Semantic equals and hash") { @@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite { test("Children of CodegenFallback") { val one = Literal(1) val two = Add(one, one) - val explode = Explode(two) - val add = Add(two, explode) + val fallback = CodegenFallbackExpression(two) + val add = Add(two, fallback) - var equivalence = new EquivalentExpressions + val equivalence = new EquivalentExpressions equivalence.addExprTree(add, true) - // the `two` inside `explode` should not be added + // the `two` inside `fallback` should not be added assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode } } + +case class CodegenFallbackExpression(child: Expression) + extends UnaryExpression with CodegenFallback { + override def dataType: DataType = child.dataType +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 19fbf0c16204..f80214af43fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -40,6 +42,10 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * + * This operator supports whole stage code generation for generators that do not implement + * terminate(). + * * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. @@ -54,7 +60,7 @@ case class GenerateExec( outer: Boolean, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -103,5 +109,197 @@ case class GenerateExec( } } } -} + override def supportCodegen: Boolean = generator.supportCodegen + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + ctx.currentVars = input + ctx.copyResult = true + + // Add input rows to the values when we are joining + val values = if (join) { + input + } else { + Seq.empty + } + + boundGenerator match { + case e: CollectionGenerator => codeGenCollection(ctx, e, values, row) + case g => codeGenTraversableOnce(ctx, g, values, row) + } + } + + /** + * Generate code for [[CollectionGenerator]] expressions. + */ + private def codeGenCollection( + ctx: CodegenContext, + e: CollectionGenerator, + input: Seq[ExprCode], + row: ExprCode): String = { + + // Generate code for the generator. + val data = e.genCode(ctx) + + // Generate looping variables. + val index = ctx.freshName("index") + + // Add a check if the generate outer flag is true. 
+ val checks = optionalCode(outer, data.isNull) + + // Add position + val position = if (e.position) { + Seq(ExprCode("", "false", index)) + } else { + Seq.empty + } + + // Generate code for either ArrayData or MapData + val (initMapData, updateRowData, values) = e.collectionType match { + case ArrayType(st: StructType, nullable) if e.inline => + val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks) + val fieldChecks = checks ++ optionalCode(nullable, row.isNull) + val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) => + codeGenAccessor(ctx, row.value, f.name, i.toString, f.dataType, f.nullable, fieldChecks) + } + ("", row.code, columns) + + case ArrayType(dataType, nullable) => + ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks))) + + case MapType(keyType, valueType, valueContainsNull) => + // Materialize the key and the value arrays before we enter the loop. + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val initArrayData = + s""" + |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray(); + |ArrayData $valueArray = ${data.isNull} ? null : ${data.value}.valueArray(); + """.stripMargin + val values = Seq( + codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks), + codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks)) + (initArrayData, "", values) + } + + // In case of outer=true we need to make sure the loop is executed at-least once when the + // array/map contains no input. We do this by setting the looping index to -1 if there is no + // input, evaluation of the array is prevented by a check in the accessor code. + val numElements = ctx.freshName("numElements") + val init = if (outer) { + s"$numElements == 0 ? -1 : 0" + } else { + "0" + } + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |${data.code} + |$initMapData + |int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements(); + |for (int $index = $init; $index < $numElements; $index++) { + | $numOutput.add(1); + | $updateRowData + | ${consume(ctx, input ++ position ++ values)} + |} + """.stripMargin + } + + /** + * Generate code for a regular [[TraversableOnce]] returning [[Generator]]. + */ + private def codeGenTraversableOnce( + ctx: CodegenContext, + e: Expression, + input: Seq[ExprCode], + row: ExprCode): String = { + + // Generate the code for the generator + val data = e.genCode(ctx) + + // Generate looping variables. + val iterator = ctx.freshName("iterator") + val hasNext = ctx.freshName("hasNext") + val current = ctx.freshName("row") + + // Add a check if the generate outer flag is true. + val checks = optionalCode(outer, s"!$hasNext") + val values = e.dataType match { + case ArrayType(st: StructType, nullable) => + st.fields.toSeq.zipWithIndex.map { case (f, i) => + codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks) + } + } + + // In case of outer=true we need to make sure the loop is executed at-least-once when the + // iterator contains no input. We do this by adding an 'outer' variable which guarantees + // execution of the first iteration even if there is no input. Evaluation of the iterator is + // prevented by checks in the next() and accessor code. 
+ val numOutput = metricTerm(ctx, "numOutputRows") + if (outer) { + val outerVal = ctx.freshName("outer") + s""" + |${data.code} + |scala.collection.Iterator $iterator = ${data.value}.toIterator(); + |boolean $outerVal = true; + |while ($iterator.hasNext() || $outerVal) { + | $numOutput.add(1); + | boolean $hasNext = $iterator.hasNext(); + | InternalRow $current = (InternalRow)($hasNext? $iterator.next() : null); + | $outerVal = false; + | ${consume(ctx, input ++ values)} + |} + """.stripMargin + } else { + s""" + |${data.code} + |scala.collection.Iterator $iterator = ${data.value}.toIterator(); + |while ($iterator.hasNext()) { + | $numOutput.add(1); + | InternalRow $current = (InternalRow)($iterator.next()); + | ${consume(ctx, input ++ values)} + |} + """.stripMargin + } + } + + /** + * Generate accessor code for ArrayData and InternalRows. + */ + private def codeGenAccessor( + ctx: CodegenContext, + source: String, + name: String, + index: String, + dt: DataType, + nullable: Boolean, + initialChecks: Seq[String]): ExprCode = { + val value = ctx.freshName(name) + val javaType = ctx.javaType(dt) + val getter = ctx.getValue(source, dt, index) + val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") + if (checks.nonEmpty) { + val isNull = ctx.freshName("isNull") + val code = + s""" + |boolean $isNull = ${checks.mkString(" || ")}; + |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$javaType $value = $getter;", "false", value) + } + } + + private def optionalCode(condition: Boolean, code: => String): Seq[String] = { + if (condition) Seq(code) + else Seq.empty + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index aedc0a8d6f70..f0995ea1d002 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -202,4 +206,34 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"), Row(1) :: Row(2) :: Nil) } + + test("SPARK-14986: Outer lateral view with empty generate expression") { + checkAnswer( + sql("select nil from values 1 lateral view outer explode(array()) n as nil"), + Row(null) :: Nil + ) + } + + test("outer explode()") { + checkAnswer( + sql("select * from values 1, 2 lateral view outer explode(array()) a as b"), + Row(1, null) :: Row(2, null) :: Nil) + } + + test("outer generator()") { + spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator()) + checkAnswer( + sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), + Row(1, null) :: Row(2, null) :: Nil) + } +} + +case class EmptyGenerator() extends Generator { + override def children: Seq[Expression] = Nil + override def elementSchema: StructType = new 
StructType().add("id", IntegerType) + override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val iteratorClass = classOf[Iterator[_]].getName + ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6b517bc70f7d..a715176d55d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2086,13 +2086,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-14986: Outer lateral view with empty generate expression") { - checkAnswer( - sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"), - Row(null) :: Nil - ) - } - test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f26e5e7b6990..e8ea7758cf59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -113,4 +115,32 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + + test("generate should be included in WholeStageCodegen") { + import org.apache.spark.sql.functions._ + val ds = spark.range(2).select( + col("id"), + explode(array(col("id") + 1, col("id") + 2)).as("value")) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined) + assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3))) + } + + test("large stack generator should not use WholeStageCodegen") { + def createStackGenerator(rows: Int): SparkPlan = { + val id = UnresolvedAttribute("id") + val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i)))) + spark.range(500).select(Column(stack)).queryExecution.executedPlan + } + val isCodeGenerated: SparkPlan => Boolean = { + case WholeStageCodegenExec(_: GenerateExec) => true + case _ => false + } + + // Only 'stack' generators that produce 50 rows or less are code generated. 
+ assert(createStackGenerator(50).find(isCodeGenerated).isDefined) + assert(createStackGenerator(100).find(isCodeGenerated).isEmpty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 470c78120b19..01773c238b0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -102,7 +102,7 @@ class MiscBenchmark extends BenchmarkBase { } benchmark.run() - /** + /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- @@ -124,7 +124,7 @@ class MiscBenchmark extends BenchmarkBase { } benchmark.run() - /** + /* model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- @@ -132,4 +132,99 @@ class MiscBenchmark extends BenchmarkBase { collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X */ } + + ignore("generate explode") { + val N = 1 << 24 + runBenchmark("generate explode array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "explode(values) value").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate explode array wholestage off 6920 / 7129 2.4 412.5 1.0X + generate explode array wholestage on 623 / 646 26.9 37.1 11.1X + */ + + runBenchmark("generate explode map", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs") + df.selectExpr("key", "explode(pairs) as (k, v)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate explode map wholestage off 11978 / 11993 1.4 714.0 1.0X + generate explode map wholestage on 866 / 919 19.4 51.6 13.8X + */ + + runBenchmark("generate posexplode array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "posexplode(values) as (idx, value)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate posexplode array wholestage off 7502 / 7513 2.2 447.1 1.0X + generate posexplode array wholestage on 617 / 623 27.2 36.8 12.2X + */ + + runBenchmark("generate inline array", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values") + df.selectExpr("key", "inline(values) as 
(r1, r2)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate inline array wholestage off 6901 / 6928 2.4 411.3 1.0X + generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X + */ + } + + ignore("generate regular generator") { + val N = 1 << 24 + runBenchmark("generate stack", N) { + val df = sparkSession.range(N).selectExpr( + "id as key", + "id % 2 as t1", + "id % 3 as t2", + "id % 5 as t3", + "id % 7 as t4", + "id % 13 as t5") + df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 + Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + + generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + generate stack wholestage off 12953 / 13070 1.3 772.1 1.0X + generate stack wholestage on 836 / 847 20.1 49.8 15.5X + */ + } } From c528812ce770fd8a6626e7f9d2f8ca9d1e84642b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 20 Nov 2016 09:52:03 +0000 Subject: [PATCH 268/381] [SPARK-3359][BUILD][DOCS] Print examples and disable group and tparam tags in javadoc ## What changes were proposed in this pull request? This PR proposes/fixes two things. - Remove many errors to generate javadoc with Java8 from unrecognisable tags, `tparam` and `group`. ``` [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:18: error: unknown tag: group [error] /** group setParam */ [error] ^ [error] .../spark/mllib/target/java/org/apache/spark/ml/classification/Classifier.java:8: error: unknown tag: tparam [error] * tparam FeaturesType Type of input features. E.g., Vector [error] ^ ... ``` It does not fully resolve the problem but remove many errors. It seems both `group` and `tparam` are unrecognisable in javadoc. It seems we can't print them pretty in javadoc in a way of `example` here because they appear differently (both examples can be found in http://spark.apache.org/docs/2.0.2/api/scala/index.html#org.apache.spark.ml.classification.Classifier). - Print `example` in javadoc. Currently, there are few `example` tag in several places. ``` ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This operation might be used to evaluate a graph ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We might use this operation to change the vertex values ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function might be used to initialize edge ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example We can use this function to compute the in-degree of each ./graphx/src/main/scala/org/apache/spark/graphx/Graph.scala: * example This function is used to update the vertices with new values based on external data. 
./graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala: * example Loads a file in the following format: ./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala: * example This function is used to update the vertices with new ./graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala: * example This function can be used to filter the graph based on some property, without ./graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala: * example We can use the Pregel abstraction to implement PageRank: ./graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala: * example Construct a `VertexRDD` from a plain RDD: ./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala: * example new SparkCommandLine(Nil).settings ./repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala: * example addImports("org.apache.spark.SparkContext") ./sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala: * example {{{ ``` **Before** 2016-11-20 2 43 23 **After** 2016-11-20 1 27 17 ## How was this patch tested? Maunally tested by `jekyll build` with Java 7 and 8 ``` java version "1.7.0_80" Java(TM) SE Runtime Environment (build 1.7.0_80-b15) Java HotSpot(TM) 64-Bit Server VM (build 24.80-b11, mixed mode) ``` ``` java version "1.8.0_45" Java(TM) SE Runtime Environment (build 1.8.0_45-b14) Java HotSpot(TM) 64-Bit Server VM (build 25.45-b02, mixed mode) ``` Note: this does not make sbt unidoc suceed with Java 8 yet but it reduces the number of errors with Java 8. Author: hyukjinkwon Closes #15939 from HyukjinKwon/SPARK-3359-javadoc. --- pom.xml | 13 +++++++++++++ project/SparkBuild.scala | 5 ++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 024b2850d0a3..7c0b0b59dc62 100644 --- a/pom.xml +++ b/pom.xml @@ -2477,11 +2477,24 @@ -Xdoclint:all -Xdoclint:-missing + + example + a + Example: + note a Note: + + group + X + + + tparam + X + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 92b45657210e..429a163d22a6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -742,7 +742,10 @@ object Unidoc { "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", "-noqualifier", "java.lang", - "-tag", """note:a:Note\:""" + "-tag", """example:a:Example\:""", + "-tag", """note:a:Note\:""", + "-tag", "group:X", + "-tag", "tparam:X" ), // Use GitHub repository for Scaladoc source links From 6659ae555a464c7a16881b660265061481c0d25f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 20 Nov 2016 13:56:08 -0800 Subject: [PATCH 269/381] Fix Mesos build break for Scala 2.10. 
--- .../deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala index b6c0b325361d..33e7d69d53d3 100644 --- a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -55,7 +55,7 @@ class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite assert(Option(mesosDispClusterArgs.masterUrl).isDefined) assert(mesosDispClusterArgs.masterUrl == masterUrl.stripPrefix("mesos://")) assert(Option(mesosDispClusterArgs.zookeeperUrl).isDefined) - assert(mesosDispClusterArgs.zookeeperUrl contains zookeeperUrl) + assert(mesosDispClusterArgs.zookeeperUrl == Some(zookeeperUrl)) assert(mesosDispClusterArgs.name == name) assert(mesosDispClusterArgs.webUiPort == webUiPort.toInt) assert(mesosDispClusterArgs.port == port.toInt) From b625a36ebc59cbacc223fc03005bc0f6d296b6e7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 20 Nov 2016 20:00:59 -0800 Subject: [PATCH 270/381] [HOTFIX][SQL] Fix DDLSuite failure. --- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index a01073987423..02d9d1568490 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1426,8 +1426,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) " + - "- Returns the concatenation of `str1`, `str2`, ..., `strN`.") :: Nil + Row("Usage: concat(str1, str2, ..., strN) - " + + "Returns the concatenation of str1, str2, ..., strN.") :: Nil ) // extended mode checkAnswer( From 658547974915ebcaae83e13e4c3bdf68d5426fda Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 21 Nov 2016 12:05:01 +0800 Subject: [PATCH 271/381] [SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`. ## What changes were proposed in this pull request? This pr extracts method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance` and modify to short circuit if arguments have `null` when `propageteNull == true`. The steps are as follows: 1. Introduce `InvokeLike` to extract common logic from `StaticInvoke`, `Invoke` and `NewInstance` to prepare arguments. `StaticInvoke` and `Invoke` had a risk to exceed 64kb JVM limit to prepare arguments but after this patch they can handle them because they share the preparing code of NewInstance, which handles the limit well. 2. Remove unneeded null checking and fix nullability of `NewInstance`. Avoid some of nullabilty checking which are not needed because the expression is not nullable. 3. Modify to short circuit if arguments have `null` when `needNullCheck == true`. 
If `needNullCheck == true`, preparing arguments can be skipped if we found one of them is `null`, so modified to short circuit in the case. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15901 from ueshin/issues/SPARK-18467. --- .../expressions/objects/objects.scala | 163 +++++++++++------- 1 file changed, 101 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e3d99127ed5..0b36091ece1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,6 +32,78 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ +/** + * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. + */ +trait InvokeLike extends Expression with NonSQLExpression { + + def arguments: Seq[Expression] + + def propagateNull: Boolean + + protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable) + + /** + * Prepares codes for arguments. + * + * - generate codes for argument. + * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. + * - avoid some of nullabilty checking which are not needed because the expression is not + * nullable. + * - when needNullCheck == true, short circuit if we found one of arguments is null because + * preparing rest of arguments can be skipped in the case. + * + * @param ctx a [[CodegenContext]] + * @return (code to prepare arguments, argument string, result of argument null check) + */ + def prepareArguments(ctx: CodegenContext): (String, String, String) = { + + val resultIsNull = if (needNullCheck) { + val resultIsNull = ctx.freshName("resultIsNull") + ctx.addMutableState("boolean", resultIsNull, "") + resultIsNull + } else { + "false" + } + val argValues = arguments.map { e => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + argValue + } + + val argCodes = if (needNullCheck) { + val reset = s"$resultIsNull = false;" + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + val updateResultIsNull = if (e.nullable) { + s"$resultIsNull = ${expr.isNull};" + } else { + "" + } + s""" + if (!$resultIsNull) { + ${expr.code} + $updateResultIsNull + ${argValues(i)} = ${expr.value}; + } + """ + } + reset +: argCodes + } else { + arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + s""" + ${expr.code} + ${argValues(i)} = ${expr.value}; + """ + } + } + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + + (argCode, argValues.mkString(", "), resultIsNull) + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. 
@@ -50,7 +122,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") @@ -62,16 +134,10 @@ case class StaticInvoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") - val callFunc = s"$objectName.$functionName($argString)" + val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = false;" - } + val callFunc = s"$objectName.$functionName($argString)" // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. @@ -82,9 +148,9 @@ case class StaticInvoke( } val code = s""" - ${argGen.map(_.code).mkString("\n")} - $setIsNull - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc; + $argCode + boolean ${ev.isNull} = $resultIsNull; + final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; $postNullCheck """ ev.copy(code = code) @@ -103,13 +169,15 @@ case class StaticInvoke( * @param functionName The name of the method to call. * @param dataType The expected return type of the function. * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. */ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -131,8 +199,8 @@ case class Invoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty @@ -164,12 +232,6 @@ case class Invoke( """ } - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = ${obj.isNull};" - } - // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. 
val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -177,15 +239,19 @@ case class Invoke( } else { "" } + val code = s""" ${obj.code} - ${argGen.map(_.code).mkString("\n")} - $setIsNull + boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $evaluate + if (!${obj.isNull}) { + $argCode + ${ev.isNull} = $resultIsNull; + if (!${ev.isNull}) { + $evaluate + } + $postNullCheck } - $postNullCheck """ ev.copy(code = code) } @@ -223,10 +289,10 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { + outerPointer: Option[() => AnyRef]) extends InvokeLike { private val className = cls.getName - override def nullable: Boolean = propagateNull + override def nullable: Boolean = needNullCheck override def children: Seq[Expression] = arguments @@ -245,52 +311,25 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argIsNulls = ctx.freshName("argIsNulls") - ctx.addMutableState("boolean[]", argIsNulls, - s"$argIsNulls = new boolean[${arguments.size}];") - val argValues = arguments.zipWithIndex.map { case (e, i) => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") - argValue - } - val argCodes = arguments.zipWithIndex.map { case (e, i) => - val expr = e.genCode(ctx) - expr.code + s""" - $argIsNulls[$i] = ${expr.isNull}; - ${argValues(i)} = ${expr.value}; - """ - } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - var isNull = ev.isNull - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s""" - boolean $isNull = false; - for (int idx = 0; idx < ${arguments.length}; idx++) { - if ($argIsNulls[idx]) { $isNull = true; break; } - } - """ - } else { - isNull = "false" - "" - } + ev.isNull = resultIsNull val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" + s"${gen.value}.new ${cls.getSimpleName}($argString)" }.getOrElse { - s"new $className(${argValues.mkString(", ")})" + s"new $className($argString)" } val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - $setIsNull - final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; - """ - ev.copy(code = code, isNull = isNull) + final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code) } override def toString: String = s"newInstance($cls)" From e811fbf9ed131bccbc46f3c5701c4ff317222fd9 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 21 Nov 2016 05:36:49 -0800 Subject: [PATCH 272/381] [SPARK-18282][ML][PYSPARK] Add python clustering summaries for GMM and BKM ## What changes were proposed in this pull request? Add model summary APIs for `GaussianMixtureModel` and `BisectingKMeansModel` in pyspark. ## How was this patch tested? Unit tests. Author: sethah Closes #15777 from sethah/pyspark_cluster_summaries. 
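The Python properties added here mirror the training summaries the Scala models already expose; a rough sketch of that underlying API (not part of this patch, assuming an active `SparkSession` and a hypothetical DataFrame `dataset` with a vector-typed `features` column):

```
// Rough sketch of the JVM-side summary API that the new Python accessors surface.
// `dataset` is a hypothetical DataFrame with a vector-typed "features" column.
import org.apache.spark.ml.clustering.{BisectingKMeans, GaussianMixture}

val bkmModel = new BisectingKMeans().setK(2).fit(dataset)
if (bkmModel.hasSummary) {
  // number of training rows assigned to each cluster
  println(bkmModel.summary.clusterSizes.mkString(", "))
}

val gmmModel = new GaussianMixture().setK(2).fit(dataset)
if (gmmModel.hasSummary) {
  gmmModel.summary.probability.show()  // per-row cluster membership probabilities
}
```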
--- .../classification/LogisticRegression.scala | 11 +- .../spark/ml/clustering/BisectingKMeans.scala | 9 +- .../spark/ml/clustering/GaussianMixture.scala | 9 +- .../apache/spark/ml/clustering/KMeans.scala | 9 +- .../GeneralizedLinearRegression.scala | 11 +- .../ml/regression/LinearRegression.scala | 14 +- .../LogisticRegressionSuite.scala | 2 + .../ml/clustering/BisectingKMeansSuite.scala | 3 + .../ml/clustering/GaussianMixtureSuite.scala | 3 + .../spark/ml/clustering/KMeansSuite.scala | 3 + .../GeneralizedLinearRegressionSuite.scala | 2 + .../ml/regression/LinearRegressionSuite.scala | 2 + python/pyspark/ml/classification.py | 15 +- python/pyspark/ml/clustering.py | 162 +++++++++++++++++- python/pyspark/ml/regression.py | 16 +- python/pyspark/ml/tests.py | 32 ++++ 16 files changed, 256 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f58efd36a1c6..d07b4adebb08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -648,7 +648,7 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + model.setSummary(Some(logRegSummary)) } else { model } @@ -790,9 +790,9 @@ class LogisticRegressionModel private[spark] ( } } - private[classification] def setSummary( - summary: LogisticRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[classification] + def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -887,8 +887,7 @@ class LogisticRegressionModel private[spark] ( override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f8a606d60b2a..e6ca3aedffd9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] ( private var trainingSummary: Option[BisectingKMeansSummary] = None - private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, 
parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index c6035cc4c964..92d0b7d085f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] ( private var trainingSummary: Option[GaussianMixtureSummary] = None - private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { + this.trainingSummary = summary this } @@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") ( .setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logNumFeatures(model.gaussians.head.mean.size) instr.logSuccess(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 26505b4cc150..152bd13b7a17 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -110,8 +110,7 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = copyValues(new KMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } /** @group setParam */ @@ -165,8 +164,8 @@ class KMeansModel private[ml] ( private var trainingSummary: Option[KMeansSummary] = None - private[clustering] def setSummary(summary: KMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 736fd3b9e0f6..3f9de1fe74c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -270,7 +270,7 @@ class 
GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, wlsModel.diagInvAtWA.toArray, 1, getSolver) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("2.0.0") @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.nonEmpty private[regression] - def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] ( override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(parent) + copied.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index da7ce6b46f2a..8ea5e1e6c453 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -225,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - return lrModel.setSummary(trainingSummary) + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -278,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. 
" + "Model cannot be regularized.") @@ -400,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), objectiveHistory) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("1.4.0") @@ -446,8 +446,9 @@ class LinearRegressionModel private[ml] ( throw new SparkException("No training summary available for this LinearRegressionModel") } - private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[regression] + def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -490,8 +491,7 @@ class LinearRegressionModel private[ml] ( @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2877285eb4d5..e360542eae2a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -147,6 +147,8 @@ class LogisticRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) } test("empty probabilityCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 49797d938d75..fc491cd6161f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -109,6 +109,9 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 7165b63ed3b9..07299123f8a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 73972557d263..c1b7242e11a8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) 
assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 6a4ac1735b2c..9b0fa67630d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index df97d0b2ae7a..0be82742a33b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -146,6 +146,8 @@ class LinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 56c8c62259e7..83e1e8934766 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -309,13 +309,16 @@ def interceptVector(self): @since("2.0.0") def summary(self): """ - Gets summary (e.g. residuals, mse, r-squared ) of model on - training set. An exception is thrown if - `trainingSummary is None`. + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. 
""" - java_blrt_summary = self._call_java("summary") - # Note: Once multiclass is added, update this to return correct summary - return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + if self.hasSummary: + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7632f05c3b68..e58ec1e7ac29 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -17,16 +17,74 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc -__all__ = ['BisectingKMeans', 'BisectingKMeansModel', +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', - 'GaussianMixture', 'GaussianMixtureModel', + 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] +class ClusteringSummary(JavaWrapper): + """ + .. note:: Experimental + + Clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def predictionCol(self): + """ + Name for column of predicted clusters in `predictions`. + """ + return self._call_java("predictionCol") + + @property + @since("2.1.0") + def predictions(self): + """ + DataFrame produced by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.1.0") + def featuresCol(self): + """ + Name for column of features in `predictions`. + """ + return self._call_java("featuresCol") + + @property + @since("2.1.0") + def k(self): + """ + The number of clusters the model was trained with. + """ + return self._call_java("k") + + @property + @since("2.1.0") + def cluster(self): + """ + DataFrame of predicted cluster centers for each training data point. + """ + return self._call_java("cluster") + + @property + @since("2.1.0") + def clusterSizes(self): + """ + Size of (number of data points in) each cluster. + """ + return self._call_java("clusterSizes") + + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -56,6 +114,28 @@ def gaussiansDF(self): """ return self._call_java("gaussiansDF") + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return GaussianMixtureSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, @@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> gm = GaussianMixture(k=3, tol=0.0001, ... 
maxIter=10, seed=10) >>> model = gm.fit(df) + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 3 + >>> summary.clusterSizes + [2, 2, 2] >>> weights = model.weights >>> len(weights) 3 @@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/gmm_model" >>> model.save(model_path) >>> model2 = GaussianMixtureModel.load(model_path) + >>> model2.hasSummary + False >>> model2.weights == model.weights True >>> model2.gaussiansDF.show() @@ -181,6 +270,32 @@ def getK(self): return self.getOrDefault(self.k) +class GaussianMixtureSummary(ClusteringSummary): + """ + .. note:: Experimental + + Gaussian mixture clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def probabilityCol(self): + """ + Name for column of predicted probability of each cluster in `predictions`. + """ + return self._call_java("probabilityCol") + + @property + @since("2.1.0") + def probability(self): + """ + DataFrame of probabilities of each cluster for each training data point. + """ + return self._call_java("probability") + + class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -346,6 +461,27 @@ def computeCost(self, dataset): """ return self._call_java("computeCost", dataset) + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return BisectingKMeansSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, @@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte 2 >>> model.computeCost(df) 2.000... + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 2 + >>> summary.clusterSizes + [2, 2] >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) + >>> model2.hasSummary + False >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] @@ -460,6 +605,17 @@ def _create_model(self, java_model): return BisectingKMeansModel(java_model) +class BisectingKMeansSummary(ClusteringSummary): + """ + .. note:: Experimental + + Bisecting KMeans clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + pass + + @inherit_doc class LDAModel(JavaModel): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0bc319ca4d60..385391ba53fd 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -160,8 +160,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. 
""" - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + if self.hasSummary: + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") @@ -1459,8 +1463,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. """ - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + if self.hasSummary: + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9d46cc3b4ae6..c0f0d4073564 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1097,6 +1097,38 @@ def test_logistic_regression_summary(self): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + class OneVsRestTests(SparkSessionTestCase): From 9f262ae163b6dca6526665b3ad12b3b2ea8fb873 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 21 Nov 2016 05:50:35 -0800 Subject: [PATCH 273/381] [SPARK-18398][SQL] Fix nullabilities of MapObjects and ExternalMapToCatalyst. ## What changes were proposed in this pull request? The nullabilities of `MapObject` can be made more strict by relying on `inputObject.nullable` and `lambdaFunction.nullable`. Also `ExternalMapToCatalyst.dataType` can be made more strict by relying on `valueConverter.nullable`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #15840 from ueshin/issues/SPARK-18398. 
--- .../spark/sql/catalyst/expressions/objects/objects.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0b36091ece1b..5c27179ec3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -461,14 +461,15 @@ case class MapObjects private( lambdaFunction: Expression, inputData: Expression) extends Expression with NonSQLExpression { - override def nullable: Boolean = true + override def nullable: Boolean = inputData.nullable override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(lambdaFunction.dataType) + override def dataType: DataType = + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -642,7 +643,8 @@ case class ExternalMapToCatalyst private( override def foldable: Boolean = false - override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType) + override def dataType: MapType = MapType( + keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") From 07beb5d21c6803e80733149f1560c71cd3cacc86 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 21 Nov 2016 13:57:36 +0000 Subject: [PATCH 274/381] [SPARK-18413][SQL] Add `maxConnections` JDBCOption ## What changes were proposed in this pull request? This PR adds a new JDBCOption `maxConnections` which means the maximum number of simultaneous JDBC connections allowed. This option applies only to writing with coalesce operation if needed. It defaults to the number of partitions of RDD. Previously, SQL users cannot cannot control this while Scala/Java/Python users can use `coalesce` (or `repartition`) API. **Reported Scenario** For the following cases, the number of connections becomes 200 and database cannot handle all of them. ```sql CREATE OR REPLACE TEMPORARY VIEW resultview USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:oracle:thin:10.129.10.111:1521:BKDB", dbtable "result", user "HIVE", password "HIVE" ); -- set spark.sql.shuffle.partitions=200 INSERT OVERWRITE TABLE resultview SELECT g, count(1) AS COUNT FROM tnet.DT_LIVE_INFO GROUP BY g ``` ## How was this patch tested? Manual. Do the followings and see Spark UI. 
**Step 1 (MySQL)** ``` CREATE TABLE t1 (a INT); CREATE TABLE data (a INT); INSERT INTO data VALUES (1); INSERT INTO data VALUES (2); INSERT INTO data VALUES (3); ``` **Step 2 (Spark)** ```scala SPARK_HOME=$PWD bin/spark-shell --driver-memory 4G --driver-class-path mysql-connector-java-5.1.40-bin.jar scala> sql("SET spark.sql.shuffle.partitions=3") scala> sql("CREATE OR REPLACE TEMPORARY VIEW data USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 'data', user 'root', password '')") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '1')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '2')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '3')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") scala> sql("CREATE OR REPLACE TEMPORARY VIEW t1 USING org.apache.spark.sql.jdbc OPTIONS (url 'jdbc:mysql://localhost:3306/t', dbtable 't1', user 'root', password '', maxConnections '4')") scala> sql("INSERT OVERWRITE TABLE t1 SELECT a FROM data GROUP BY a") ``` ![maxconnections](https://cloud.githubusercontent.com/assets/9700541/20287987/ed8409c2-aa84-11e6-8aab-ae28e63fe54d.png) Author: Dongjoon Hyun Closes #15868 from dongjoon-hyun/SPARK-18413. --- docs/sql-programming-guide.md | 7 +++++++ .../sql/execution/datasources/jdbc/JDBCOptions.scala | 6 ++++++ .../sql/execution/datasources/jdbc/JdbcUtils.scala | 9 ++++++++- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 12 ++++++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ba3e55fc061a..656e7ecdab0b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1086,6 +1086,13 @@ the following case-sensitive options: + + maxConnections + + The maximum number of concurrent JDBC connections that can be used, if set. Only applies when writing. It works by limiting the operation's parallelism, which depends on the input's partition count. If its partition count exceeds this limit, the operation will coalesce the input to fewer partitions before writing. + + + isolationLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7f419b5788c4..d416eec6ddae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -122,6 +122,11 @@ class JDBCOptions( case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE } + // the maximum number of connections + val maxConnections = parameters.get(JDBC_MAX_CONNECTIONS).map(_.toInt) + require(maxConnections.isEmpty || maxConnections.get > 0, + s"Invalid value `${maxConnections.get}` for parameter `$JDBC_MAX_CONNECTIONS`. 
" + + "The minimum value is 1.") } object JDBCOptions { @@ -144,4 +149,5 @@ object JDBCOptions { val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") + val JDBC_MAX_CONNECTIONS = newOption("maxConnections") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 41edb6511c2c..cdc3c99daa1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -667,7 +667,14 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel - df.foreachPartition(iterator => savePartition( + val maxConnections = options.maxConnections + val repartitionedDF = + if (maxConnections.isDefined && maxConnections.get < df.rdd.getNumPartitions) { + df.coalesce(maxConnections.get) + } else { + df + } + repartitionedDF.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e3d3c6c3a887..5795b4d860cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -312,4 +312,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .options(properties.asScala) .save() } + + test("SPARK-18413: Add `maxConnections` JDBCOption") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val e = intercept[IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option(s"${JDBCOptions.JDBC_MAX_CONNECTIONS}", "0") + .save() + }.getMessage + assert(e.contains("Invalid value `0` for parameter `maxConnections`. The minimum value is 1")) + } } From 70176871ae10509f1a727a96e96b3da7762605b1 Mon Sep 17 00:00:00 2001 From: Gabriel Huang Date: Mon, 21 Nov 2016 16:08:34 -0500 Subject: [PATCH 275/381] [SPARK-18361][PYSPARK] Expose RDD localCheckpoint in PySpark ## What changes were proposed in this pull request? Expose RDD's localCheckpoint() and associated functions in PySpark. ## How was this patch tested? I added a UnitTest in python/pyspark/tests.py which passes. I certify that this is my original work, and I license it to the project under the project's open source license. Gabriel HUANG Developer at Cardabel (http://cardabel.com/) Author: Gabriel Huang Closes #15811 from gabrielhuang/pyspark-localcheckpoint. --- python/pyspark/rdd.py | 33 ++++++++++++++++++++++++++++++++- python/pyspark/tests.py | 17 +++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 641787ee20e0..f21a364df910 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -263,13 +263,44 @@ def checkpoint(self): def isCheckpointed(self): """ - Return whether this RDD has been checkpointed or not + Return whether this RDD is checkpointed and materialized, either reliably or locally. 
""" return self._jrdd.rdd().isCheckpointed() + def localCheckpoint(self): + """ + Mark this RDD for local checkpointing using Spark's existing caching layer. + + This method is for users who wish to truncate RDD lineages while skipping the expensive + step of replicating the materialized data in a reliable distributed file system. This is + useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX). + + Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed + data is written to ephemeral local storage in the executors instead of to a reliable, + fault-tolerant storage. The effect is that if an executor fails during the computation, + the checkpointed data may no longer be accessible, causing an irrecoverable job failure. + + This is NOT safe to use with dynamic allocation, which removes executors along + with their cached blocks. If you must use both features, you are advised to set + L{spark.dynamicAllocation.cachedExecutorIdleTimeout} to a high value. + + The checkpoint directory set through L{SparkContext.setCheckpointDir()} is not used. + """ + self._jrdd.rdd().localCheckpoint() + + def isLocallyCheckpointed(self): + """ + Return whether this RDD is marked for local checkpointing. + + Exposed for testing. + """ + return self._jrdd.rdd().isLocallyCheckpointed() + def getCheckpointFile(self): """ Gets the name of the file to which this RDD was checkpointed + + Not defined if RDD is checkpointed locally. """ checkpointFile = self._jrdd.rdd().getCheckpointFile() if checkpointFile.isDefined(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3e0bd16d85ca..ab4bef8329cd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -390,6 +390,23 @@ def test_checkpoint_and_restore(self): self.assertEqual([1, 2, 3, 4], recovered.collect()) +class LocalCheckpointTests(ReusedPySparkTestCase): + + def test_basic_localcheckpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) + + flatMappedRDD.localCheckpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + + class AddFileTests(PySparkTestCase): def test_add_py_file(self): From ddd02f50bb7458410d65427321efc75da5e65224 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 21 Nov 2016 16:14:59 -0500 Subject: [PATCH 276/381] [SPARK-18517][SQL] DROP TABLE IF EXISTS should not warn for non-existing tables ## What changes were proposed in this pull request? Currently, `DROP TABLE IF EXISTS` shows warning for non-existing tables. However, it had better be quiet for this case by definition of the command. **BEFORE** ```scala scala> sql("DROP TABLE IF EXISTS nonexist") 16/11/20 20:48:26 WARN DropTableCommand: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'nonexist' not found in database 'default'; ``` **AFTER** ```scala scala> sql("DROP TABLE IF EXISTS nonexist") res0: org.apache.spark.sql.DataFrame = [] ``` ## How was this patch tested? Manual because this is related to the warning messages instead of exceptions. Author: Dongjoon Hyun Closes #15953 from dongjoon-hyun/SPARK-18517. 
--- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 588aa05c37b4..d80b000bcc59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison} @@ -203,6 +203,7 @@ case class DropTableCommand( sparkSession.sharedState.cacheManager.uncacheQuery( sparkSession.table(tableName.quotedString)) } catch { + case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) } catalog.refreshTable(tableName) From a2d464770cd183daa7d727bf377bde9c21e29e6a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 21 Nov 2016 13:23:32 -0800 Subject: [PATCH 277/381] [SPARK-17765][SQL] Support for writing out user-defined type in ORC datasource ## What changes were proposed in this pull request? This PR adds the support for `UserDefinedType` when writing out instead of throwing `ClassCastException` in ORC data source. In more details, `OrcStruct` is being created based on string from`DataType.catalogString`. For user-defined type, it seems it returns `sqlType.simpleString` for `catalogString` by default[1]. However, during type-dispatching to match the output with the schema, it tries to cast to, for example, `StructType`[2]. So, running the codes below (`MyDenseVector` was borrowed[3]) : ``` scala val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) val udtDF = data.toDF("id", "vectors") udtDF.write.orc("/tmp/test.orc") ``` ends up throwing an exception as below: ``` java.lang.ClassCastException: org.apache.spark.sql.UDT$MyDenseVectorUDT cannot be cast to org.apache.spark.sql.types.ArrayType at org.apache.spark.sql.hive.HiveInspectors$class.wrapperFor(HiveInspectors.scala:381) at org.apache.spark.sql.hive.orc.OrcSerializer.wrapperFor(OrcFileFormat.scala:164) ... ``` So, this PR uses `UserDefinedType.sqlType` during finding the correct converter when writing out in ORC data source. [1]https://github.com/apache/spark/blob/dfdcab00c7b6200c22883baa3ebc5818be09556f/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala#L95 [2]https://github.com/apache/spark/blob/d2dc8c4a162834818190ffd82894522c524ca3e5/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala#L326 [3]https://github.com/apache/spark/blob/2bfed1a0c5be7d0718fd574a4dad90f4f6b44be7/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala#L38-L70 ## How was this patch tested? Unit tests in `OrcQuerySuite`. Author: hyukjinkwon Closes #15361 from HyukjinKwon/SPARK-17765. 
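A rough PySpark analogue of the new test, assuming a session with Hive support (the ORC source lives in `sql/hive`) and using `pyspark.ml.linalg` vectors, whose `VectorUDT` is a user-defined type; as in the Scala test, the schema is supplied when reading back because ORC stores only the UDT's underlying SQL type:

```python
from pyspark.ml.linalg import Vectors

udt_df = spark.createDataFrame(
    [(1, Vectors.dense([0.25, 2.25, 4.25]))], ["id", "vectors"])
udt_df.write.mode("overwrite").orc("/tmp/udt_orc_test")

# Reading back with the original schema restores the UDT column.
read_back = spark.read.schema(udt_df.schema).orc("/tmp/udt_orc_test")
read_back.show(truncate=False)
```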
--- .../org/apache/spark/sql/hive/HiveInspectors.scala | 3 +++ .../org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index e303065127c3..52aa1088acd4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -246,6 +246,9 @@ private[hive] trait HiveInspectors { * Wraps with Hive types based on object inspector. */ protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { + case _ if dataType.isInstanceOf[UserDefinedType[_]] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + wrapperFor(oi, sqlType) case x: ConstantObjectInspector => (o: Any) => x.getWritableConstantValue diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index a628977af2f4..b8761e9de288 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -93,6 +93,16 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("Read/write UserDefinedType") { + withTempPath { path => + val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) + val udtDF = data.toDF("id", "vectors") + udtDF.write.orc(path.getAbsolutePath) + val readBack = spark.read.schema(udtDF.schema).orc(path.getAbsolutePath) + checkAnswer(udtDF, readBack) + } + } + test("Creating case class RDD table") { val data = (1 to 100).map(i => (i, s"val_$i")) sparkContext.parallelize(data).toDF().createOrReplaceTempView("t") From 97a8239a625df455d2c439f3628a529d6d9413ca Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 21 Nov 2016 17:24:02 -0800 Subject: [PATCH 278/381] [SPARK-18493] Add missing python APIs: withWatermark and checkpoint to dataframe ## What changes were proposed in this pull request? This PR adds two of the newly added methods of `Dataset`s to Python: `withWatermark` and `checkpoint` ## How was this patch tested? Doc tests Author: Burak Yavuz Closes #15921 from brkyvz/py-watermark. --- python/pyspark/sql/dataframe.py | 57 ++++++++++++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 10 +++- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 38998900837c..6fe622643291 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -322,6 +322,54 @@ def show(self, n=20, truncate=True): def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + @since(2.1) + def checkpoint(self, eager=True): + """Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + logical plan of this DataFrame, which is especially useful in iterative algorithms where the + plan may grow exponentially. It will be saved to files inside the checkpoint + directory set with L{SparkContext.setCheckpointDir()}. + + :param eager: Whether to checkpoint this DataFrame immediately + + .. 
note:: Experimental + """ + jdf = self._jdf.checkpoint(eager) + return DataFrame(jdf, self.sql_ctx) + + @since(2.1) + def withWatermark(self, eventTime, delayThreshold): + """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point + in time before which we assume no more late data is going to arrive. + + Spark will use this watermark for several purposes: + - To know when a given time window aggregation can be finalized and thus can be emitted + when using output modes that do not allow updates. + + - To minimize the amount of state that we need to keep for on-going aggregations. + + The current watermark is computed by looking at the `MAX(eventTime)` seen across + all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost + of coordinating this value across partitions, the actual watermark used is only guaranteed + to be at least `delayThreshold` behind the actual event time. In some cases we may still + process records that arrive more than `delayThreshold` late. + + :param eventTime: the name of the column that contains the event time of the row. + :param delayThreshold: the minimum delay to wait to data to arrive late, relative to the + latest record that has been processed in the form of an interval + (e.g. "1 minute" or "5 hours"). + + .. note:: Experimental + + >>> sdf.select('name', sdf.time.cast('timestamp')).withWatermark('time', '10 minutes') + DataFrame[name: string, time: timestamp] + """ + if not eventTime or type(eventTime) is not str: + raise TypeError("eventTime should be provided as a string") + if not delayThreshold or type(delayThreshold) is not str: + raise TypeError("delayThreshold should be provided as a string interval") + jdf = self._jdf.withWatermark(eventTime, delayThreshold) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -1626,6 +1674,7 @@ def _test(): from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, SparkSession import pyspark.sql.dataframe + from pyspark.sql.functions import from_unixtime globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc @@ -1638,9 +1687,11 @@ def _test(): globs['df3'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), - Row(name='Bob', age=5, height=None), - Row(name='Tom', age=None, height=None), - Row(name=None, age=None, height=None)]).toDF() + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), + Row(name='Bob', time=1479442946)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c75a6a45ec8..7ba6ffce278c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -485,7 +485,10 @@ class Dataset[T] private[sql]( def isStreaming: Boolean = logicalPlan.isStreaming /** - * Returns a checkpointed version of this Dataset. + * Eagerly checkpoint a Dataset and return the new Dataset. 
Checkpointing can be used to truncate + * the logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. * * @group basic * @since 2.1.0 @@ -495,7 +498,10 @@ class Dataset[T] private[sql]( def checkpoint(): Dataset[T] = checkpoint(eager = true) /** - * Returns a checkpointed version of this Dataset. + * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. * * @group basic * @since 2.1.0 From ebeb0830a3a4837c7354a0eee667b9f5fad389c5 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Mon, 21 Nov 2016 21:14:13 -0800 Subject: [PATCH 279/381] [SPARK-18425][STRUCTURED STREAMING][TESTS] Test `CompactibleFileStreamLog` directly ## What changes were proposed in this pull request? Right now we are testing the most of `CompactibleFileStreamLog` in `FileStreamSinkLogSuite` (because `FileStreamSinkLog` once was the only subclass of `CompactibleFileStreamLog`, but now it's not the case any more). Let's refactor the tests so that `CompactibleFileStreamLog` is directly tested, making future changes (like https://github.com/apache/spark/pull/15828, https://github.com/apache/spark/pull/15827) to `CompactibleFileStreamLog` much easier to test and much easier to review. ## How was this patch tested? the PR itself is about tests Author: Liwei Lin Closes #15870 from lw-lin/test-compact-1113. --- .../CompactibleFileStreamLogSuite.scala | 216 +++++++++++++++++- .../streaming/FileStreamSinkLogSuite.scala | 68 ------ 2 files changed, 214 insertions(+), 70 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 2cd2157b293c..e511fda57912 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -17,12 +17,79 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.SparkFunSuite +import java.io._ +import java.nio.charset.StandardCharsets._ -class CompactibleFileStreamLogSuite extends SparkFunSuite { +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.streaming.FakeFileSystem._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SharedSQLContext + +class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { + + /** To avoid caching of FS objects */ + override protected val sparkConf = + new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ + /** -- testing of `object CompactibleFileStreamLog` begins -- */ + + test("getBatchIdFromFileName") { + assert(1234L === getBatchIdFromFileName("1234")) + assert(1234L === getBatchIdFromFileName("1234.compact")) + intercept[NumberFormatException] { + getBatchIdFromFileName("1234a") + } + } + + test("isCompactionBatch") { + assert(false === isCompactionBatch(0, compactInterval = 3)) + assert(false === isCompactionBatch(1, compactInterval = 3)) + assert(true 
=== isCompactionBatch(2, compactInterval = 3)) + assert(false === isCompactionBatch(3, compactInterval = 3)) + assert(false === isCompactionBatch(4, compactInterval = 3)) + assert(true === isCompactionBatch(5, compactInterval = 3)) + } + + test("nextCompactionBatchId") { + assert(2 === nextCompactionBatchId(0, compactInterval = 3)) + assert(2 === nextCompactionBatchId(1, compactInterval = 3)) + assert(5 === nextCompactionBatchId(2, compactInterval = 3)) + assert(5 === nextCompactionBatchId(3, compactInterval = 3)) + assert(5 === nextCompactionBatchId(4, compactInterval = 3)) + assert(8 === nextCompactionBatchId(5, compactInterval = 3)) + } + + test("getValidBatchesBeforeCompactionBatch") { + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(0, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(1, compactInterval = 3) + } + assert(Seq(0, 1) === getValidBatchesBeforeCompactionBatch(2, compactInterval = 3)) + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(3, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(4, compactInterval = 3) + } + assert(Seq(2, 3, 4) === getValidBatchesBeforeCompactionBatch(5, compactInterval = 3)) + } + + test("getAllValidBatches") { + assert(Seq(0) === getAllValidBatches(0, compactInterval = 3)) + assert(Seq(0, 1) === getAllValidBatches(1, compactInterval = 3)) + assert(Seq(2) === getAllValidBatches(2, compactInterval = 3)) + assert(Seq(2, 3) === getAllValidBatches(3, compactInterval = 3)) + assert(Seq(2, 3, 4) === getAllValidBatches(4, compactInterval = 3)) + assert(Seq(5) === getAllValidBatches(5, compactInterval = 3)) + assert(Seq(5, 6) === getAllValidBatches(6, compactInterval = 3)) + assert(Seq(5, 6, 7) === getAllValidBatches(7, compactInterval = 3)) + assert(Seq(8) === getAllValidBatches(8, compactInterval = 3)) + } + test("deriveCompactInterval") { // latestCompactBatchId(4) + 1 <= default(5) // then use latestestCompactBatchId + 1 === 5 @@ -30,4 +97,149 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite { // First divisor of 10 greater than 4 === 5 assert(5 === deriveCompactInterval(4, 9)) } + + /** -- testing of `object CompactibleFileStreamLog` ends -- */ + + test("batchIdToPath") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + assert("0" === compactibleLog.batchIdToPath(0).getName) + assert("1" === compactibleLog.batchIdToPath(1).getName) + assert("2.compact" === compactibleLog.batchIdToPath(2).getName) + assert("3" === compactibleLog.batchIdToPath(3).getName) + assert("4" === compactibleLog.batchIdToPath(4).getName) + assert("5.compact" === compactibleLog.batchIdToPath(5).getName) + }) + } + + test("serialize") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + val logs = Array("entry_1", "entry_2", "entry_3") + val expected = s"""${FakeCompactibleFileStreamLog.VERSION} + |"entry_1" + |"entry_2" + |"entry_3"""".stripMargin + val baos = new ByteArrayOutputStream() + compactibleLog.serialize(logs, baos) + assert(expected === baos.toString(UTF_8.name())) + + baos.reset() + compactibleLog.serialize(Array(), baos) + assert(FakeCompactibleFileStreamLog.VERSION === baos.toString(UTF_8.name())) + }) + } + + test("deserialize") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + val logs = 
s"""${FakeCompactibleFileStreamLog.VERSION} + |"entry_1" + |"entry_2" + |"entry_3"""".stripMargin + val expected = Array("entry_1", "entry_2", "entry_3") + assert(expected === + compactibleLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) + + assert(Nil === + compactibleLog.deserialize( + new ByteArrayInputStream(FakeCompactibleFileStreamLog.VERSION.getBytes(UTF_8)))) + }) + } + + testWithUninterruptibleThread("compact") { + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = Long.MaxValue, + defaultCompactInterval = 3, + compactibleLog => { + for (batchId <- 0 to 10) { + compactibleLog.add(batchId, Array("some_path_" + batchId)) + val expectedFiles = (0 to batchId).map { id => "some_path_" + id } + assert(compactibleLog.allFiles() === expectedFiles) + if (isCompactionBatch(batchId, 3)) { + // Since batchId is a compaction batch, the batch log file should contain all logs + assert(compactibleLog.get(batchId).getOrElse(Nil) === expectedFiles) + } + } + }) + } + + testWithUninterruptibleThread("delete expired file") { + // Set `fileCleanupDelayMs` to 0 so that we can detect the deleting behaviour deterministically + withFakeCompactibleFileStreamLog( + fileCleanupDelayMs = 0, + defaultCompactInterval = 3, + compactibleLog => { + val fs = compactibleLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) + + def listBatchFiles(): Set[String] = { + fs.listStatus(compactibleLog.metadataPath).map(_.getPath.getName).filter { fileName => + try { + getBatchIdFromFileName(fileName) + true + } catch { + case _: NumberFormatException => false + } + }.toSet + } + + compactibleLog.add(0, Array("some_path_0")) + assert(Set("0") === listBatchFiles()) + compactibleLog.add(1, Array("some_path_1")) + assert(Set("0", "1") === listBatchFiles()) + compactibleLog.add(2, Array("some_path_2")) + assert(Set("2.compact") === listBatchFiles()) + compactibleLog.add(3, Array("some_path_3")) + assert(Set("2.compact", "3") === listBatchFiles()) + compactibleLog.add(4, Array("some_path_4")) + assert(Set("2.compact", "3", "4") === listBatchFiles()) + compactibleLog.add(5, Array("some_path_5")) + assert(Set("5.compact") === listBatchFiles()) + }) + } + + private def withFakeCompactibleFileStreamLog( + fileCleanupDelayMs: Long, + defaultCompactInterval: Int, + f: FakeCompactibleFileStreamLog => Unit + ): Unit = { + withTempDir { file => + val compactibleLog = new FakeCompactibleFileStreamLog( + fileCleanupDelayMs, + defaultCompactInterval, + spark, + file.getCanonicalPath) + f(compactibleLog) + } + } +} + +object FakeCompactibleFileStreamLog { + val VERSION = "test_version" +} + +class FakeCompactibleFileStreamLog( + _fileCleanupDelayMs: Long, + _defaultCompactInterval: Int, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[String]( + FakeCompactibleFileStreamLog.VERSION, + sparkSession, + path + ) { + + override protected def fileCleanupDelayMs: Long = _fileCleanupDelayMs + + override protected def isDeletingExpiredLog: Boolean = true + + override protected def defaultCompactInterval: Int = _defaultCompactInterval + + override def compactLogs(logs: Seq[String]): Seq[String] = logs } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index e1bc674a2807..e046fee0c04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -29,61 +29,6 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { import CompactibleFileStreamLog._ import FileStreamSinkLog._ - test("getBatchIdFromFileName") { - assert(1234L === getBatchIdFromFileName("1234")) - assert(1234L === getBatchIdFromFileName("1234.compact")) - intercept[NumberFormatException] { - getBatchIdFromFileName("1234a") - } - } - - test("isCompactionBatch") { - assert(false === isCompactionBatch(0, compactInterval = 3)) - assert(false === isCompactionBatch(1, compactInterval = 3)) - assert(true === isCompactionBatch(2, compactInterval = 3)) - assert(false === isCompactionBatch(3, compactInterval = 3)) - assert(false === isCompactionBatch(4, compactInterval = 3)) - assert(true === isCompactionBatch(5, compactInterval = 3)) - } - - test("nextCompactionBatchId") { - assert(2 === nextCompactionBatchId(0, compactInterval = 3)) - assert(2 === nextCompactionBatchId(1, compactInterval = 3)) - assert(5 === nextCompactionBatchId(2, compactInterval = 3)) - assert(5 === nextCompactionBatchId(3, compactInterval = 3)) - assert(5 === nextCompactionBatchId(4, compactInterval = 3)) - assert(8 === nextCompactionBatchId(5, compactInterval = 3)) - } - - test("getValidBatchesBeforeCompactionBatch") { - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(0, compactInterval = 3) - } - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(1, compactInterval = 3) - } - assert(Seq(0, 1) === getValidBatchesBeforeCompactionBatch(2, compactInterval = 3)) - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(3, compactInterval = 3) - } - intercept[AssertionError] { - getValidBatchesBeforeCompactionBatch(4, compactInterval = 3) - } - assert(Seq(2, 3, 4) === getValidBatchesBeforeCompactionBatch(5, compactInterval = 3)) - } - - test("getAllValidBatches") { - assert(Seq(0) === getAllValidBatches(0, compactInterval = 3)) - assert(Seq(0, 1) === getAllValidBatches(1, compactInterval = 3)) - assert(Seq(2) === getAllValidBatches(2, compactInterval = 3)) - assert(Seq(2, 3) === getAllValidBatches(3, compactInterval = 3)) - assert(Seq(2, 3, 4) === getAllValidBatches(4, compactInterval = 3)) - assert(Seq(5) === getAllValidBatches(5, compactInterval = 3)) - assert(Seq(5, 6) === getAllValidBatches(6, compactInterval = 3)) - assert(Seq(5, 6, 7) === getAllValidBatches(7, compactInterval = 3)) - assert(Seq(8) === getAllValidBatches(8, compactInterval = 3)) - } - test("compactLogs") { withFileStreamSinkLog { sinkLog => val logs = Seq( @@ -184,19 +129,6 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("batchIdToPath") { - withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { - withFileStreamSinkLog { sinkLog => - assert("0" === sinkLog.batchIdToPath(0).getName) - assert("1" === sinkLog.batchIdToPath(1).getName) - assert("2.compact" === sinkLog.batchIdToPath(2).getName) - assert("3" === sinkLog.batchIdToPath(3).getName) - assert("4" === sinkLog.batchIdToPath(4).getName) - assert("5.compact" === sinkLog.batchIdToPath(5).getName) - } - } - } - testWithUninterruptibleThread("compact") { withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { withFileStreamSinkLog { sinkLog => From acb97157796231fef74aba985825b05b607b9279 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Nov 2016 00:05:30 -0800 Subject: [PATCH 280/381] [SPARK-18444][SPARKR] SparkR running in yarn-cluster mode should not 
download Spark package. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When running SparkR job in yarn-cluster mode, it will download Spark package from apache website which is not necessary. ``` ./bin/spark-submit --master yarn-cluster ./examples/src/main/r/dataframe.R ``` The following is output: ``` Attaching package: ‘SparkR’ The following objects are masked from ‘package:stats’: cov, filter, lag, na.omit, predict, sd, var, window The following objects are masked from ‘package:base’: as.data.frame, colnames, colnames<-, drop, endsWith, intersect, rank, rbind, sample, startsWith, subset, summary, transform, union Spark not found in SPARK_HOME: Spark not found in the cache directory. Installation will start. MirrorUrl not provided. Looking for preferred site from apache website... ...... ``` There's no ```SPARK_HOME``` in yarn-cluster mode since the R process is in a remote host of the yarn cluster rather than in the client host. The JVM comes up first and the R process then connects to it. So in such cases we should never have to download Spark as Spark is already running. ## How was this patch tested? Offline test. Author: Yanbo Liang Closes #15888 from yanboliang/spark-18444. --- R/pkg/R/sparkR.R | 20 +++++++---- R/pkg/R/utils.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkR.R | 46 +++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 R/pkg/inst/tests/testthat/test_sparkR.R diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 6b4a2f2fdc85..a7152b431399 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -373,8 +373,13 @@ sparkR.session <- function( overrideEnvs(sparkConfigMap, paramMap) } + deployMode <- "" + if (exists("spark.submit.deployMode", envir = sparkConfigMap)) { + deployMode <- sparkConfigMap[["spark.submit.deployMode"]] + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { - retHome <- sparkCheckInstall(sparkHome, master) + retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, @@ -550,24 +555,27 @@ processSparkPackages <- function(packages) { # # @param sparkHome directory to find Spark package. # @param master the Spark master URL, used to check local or remote mode. +# @param deployMode whether to deploy your driver on the worker nodes (cluster) +# or locally as an external client (client). # @return NULL if no need to update sparkHome, and new sparkHome otherwise. 
-sparkCheckInstall <- function(sparkHome, master) { +sparkCheckInstall <- function(sparkHome, master, deployMode) { if (!isSparkRShell()) { if (!is.na(file.info(sparkHome)$isdir)) { msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) message(msg) NULL } else { - if (!nzchar(master) || isMasterLocal(master)) { - msg <- paste0("Spark not found in SPARK_HOME: ", - sparkHome) + if (isMasterLocal(master)) { + msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome) message(msg) packageLocalDir <- install.spark() packageLocalDir - } else { + } else if (isClientMode(master) || deployMode == "client") { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome, "\n", installInstruction("remote")) stop(msg) + } else { + NULL } } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 20004549cc03..098c0e3e31e9 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -777,6 +777,10 @@ isMasterLocal <- function(master) { grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) } +isClientMode <- function(master) { + grepl("([a-z]+)-client$", master, perl = TRUE) +} + isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R new file mode 100644 index 000000000000..f73fc6baecce --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in sparkR.R") + +test_that("sparkCheckInstall", { + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- paste0(tempdir(), "/", "sparkHome") + dir.create(sparkHome) + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + unlink(sparkHome, recursive = TRUE) + + # "yarn-cluster, mesos-cluster" mode, SPARK_HOME was not set, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- "" + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + + # "yarn-client, mesos-client" mode, SPARK_HOME was not set + sparkHome <- "" + master <- "yarn-client" + deployMode <- "" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) + sparkHome <- "" + master <- "" + deployMode <- "client" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) +}) From 4922f9cdcac8b7c10320ac1fb701997fffa45d46 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 11:26:10 +0000 Subject: [PATCH 281/381] [SPARK-18514][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that` across R API documentation ## What changes were proposed in this pull request? 
It seems in R, there are - `Note:` - `NOTE:` - `Note that` This PR proposes to fix those to `Note:` to be consistent. **Before** ![2016-11-21 11 30 07](https://cloud.githubusercontent.com/assets/6477701/20468848/2f27b0fa-afde-11e6-89e3-993701269dbe.png) **After** ![2016-11-21 11 29 44](https://cloud.githubusercontent.com/assets/6477701/20468851/39469664-afde-11e6-9929-ad80be7fc405.png) ## How was this patch tested? The notes were found via ```bash grep -r "NOTE: " . grep -r "Note that " . ``` And then fixed one by one comparing with API documentation. After that, manually tested via `sh create-docs.sh` under `./R`. Author: hyukjinkwon Closes #15952 from HyukjinKwon/SPARK-18514. --- R/pkg/R/DataFrame.R | 6 ++++-- R/pkg/R/functions.R | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4e3d97bb3ad0..9a51d530f120 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2541,7 +2541,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2584,7 +2585,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' #' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x a SparkDataFrame. #' @param ... additional SparkDataFrame(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f8a9d3ce5d91..bf5c96373c63 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2296,7 +2296,7 @@ setMethod("n", signature(x = "Column"), #' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. #' -#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' #' @param y Column to compute on. @@ -2341,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2779,7 +2779,8 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr +#' +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. 
From 933a6548d423cf17448207a99299cf36fc1a95f6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 11:40:18 +0000 Subject: [PATCH 282/381] [SPARK-18447][DOCS] Fix the markdown for `Note:`/`NOTE:`/`Note that` across Python API documentation ## What changes were proposed in this pull request? It seems in Python, there are - `Note:` - `NOTE:` - `Note that` - `.. note::` This PR proposes to fix those to `.. note::` to be consistent. **Before** 2016-11-21 1 18 49 2016-11-21 12 42 43 **After** 2016-11-21 1 18 42 2016-11-21 12 42 51 ## How was this patch tested? The notes were found via ```bash grep -r "Note: " . grep -r "NOTE: " . grep -r "Note that " . ``` And then fixed one by one comparing with API documentation. After that, manually tested via `make html` under `./python/docs`. Author: hyukjinkwon Closes #15947 from HyukjinKwon/SPARK-18447. --- python/pyspark/conf.py | 4 +- python/pyspark/context.py | 8 ++-- python/pyspark/ml/classification.py | 45 +++++++++--------- python/pyspark/ml/clustering.py | 8 ++-- python/pyspark/ml/feature.py | 13 +++--- python/pyspark/ml/linalg/__init__.py | 11 +++-- python/pyspark/ml/regression.py | 32 ++++++------- python/pyspark/mllib/clustering.py | 6 +-- python/pyspark/mllib/feature.py | 24 +++++----- python/pyspark/mllib/linalg/__init__.py | 11 +++-- python/pyspark/mllib/linalg/distributed.py | 15 +++--- python/pyspark/mllib/regression.py | 2 +- python/pyspark/mllib/stat/_statistics.py | 3 +- python/pyspark/mllib/tree.py | 12 ++--- python/pyspark/rdd.py | 54 +++++++++++----------- python/pyspark/sql/dataframe.py | 28 ++++++----- python/pyspark/sql/functions.py | 11 +++-- python/pyspark/sql/streaming.py | 10 ++-- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/kinesis.py | 4 +- 20 files changed, 157 insertions(+), 146 deletions(-) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 64b6f238e9c3..491b3a81972b 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -90,8 +90,8 @@ class SparkConf(object): All setter methods in this class support chaining. For example, you can write C{conf.setMaster("local").setAppName("My app")}. - Note that once a SparkConf object is passed to Spark, it is cloned - and can no longer be modified by the user. + .. note:: Once a SparkConf object is passed to Spark, it is cloned + and can no longer be modified by the user. """ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2c2cf6a373bb..2fd3aee01d76 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -520,8 +520,8 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): ... (a-hdfs-path/part-nnnnn, its content) - NOTE: Small files are preferred, as each file will be loaded - fully in memory. + .. note:: Small files are preferred, as each file will be loaded + fully in memory. >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) @@ -547,8 +547,8 @@ def binaryFiles(self, path, minPartitions=None): in a key-value pair, where the key is the path of each file, the value is the content of each file. - Note: Small files are preferred, large file is also allowable, but - may cause bad performance. + .. note:: Small files are preferred, large file is also allowable, but + may cause bad performance. 
""" minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.binaryFiles(path, minPartitions), self, diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83e1e8934766..8054a34db30f 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -440,9 +440,9 @@ def roc(self): .. seealso:: `Wikipedia reference \ `_ - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("roc") @@ -453,9 +453,9 @@ def areaUnderROC(self): Computes the area under the receiver operating characteristic (ROC) curve. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("areaUnderROC") @@ -467,9 +467,9 @@ def pr(self): containing two fields recall, precision with (0.0, 1.0) prepended to it. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("pr") @@ -480,9 +480,9 @@ def fMeasureByThreshold(self): Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("fMeasureByThreshold") @@ -494,9 +494,9 @@ def precisionByThreshold(self): Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the precision. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("precisionByThreshold") @@ -508,9 +508,9 @@ def recallByThreshold(self): Every possible probability obtained in transforming the dataset are used as thresholds used in calculating the recall. - Note: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("recallByThreshold") @@ -695,9 +695,9 @@ def featureImportances(self): where gain is scaled by the number of instances passing through node - Normalize importances for tree to sum to 1. - Note: Feature importance for single decision trees can have high variance due to - correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` - to determine feature importance instead. + .. 
note:: Feature importance for single decision trees can have high variance due to + correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` + to determine feature importance instead. """ return self._call_java("featureImportances") @@ -839,7 +839,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol `Gradient-Boosted Trees (GBTs) `_ learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features. - Note: Multiclass labels are not currently supported. The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. @@ -851,6 +850,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol - We expect to implement TreeBoost in the future: `SPARK-4240 `_ + .. note:: Multiclass labels are not currently supported. + >>> from numpy import allclose >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index e58ec1e7ac29..b29b5ac70e6f 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -155,7 +155,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte While this process is generally guaranteed to converge, it is not guaranteed to find a global optimum. - Note: For high-dimensional data (with many features), this algorithm may perform poorly. + .. note:: For high-dimensional data (with many features), this algorithm may perform poorly. This is due to high-dimensional data (a) making it difficult to cluster at all (based on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. @@ -749,9 +749,9 @@ def getCheckpointFiles(self): If using checkpointing and :py:attr:`LDA.keepLastCheckpoint` is set to true, then there may be saved checkpoint files. This method is provided so that users can manage those files. - Note that removing the checkpoints can cause failures if a partition is lost and is needed - by certain :py:class:`DistributedLDAModel` methods. Reference counting will clean up the - checkpoints when this model and derivative data go out of scope. + .. note:: Removing the checkpoints can cause failures if a partition is lost and is needed + by certain :py:class:`DistributedLDAModel` methods. Reference counting will clean up + the checkpoints when this model and derivative data go out of scope. :return List of checkpoint files from training """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 635cf1304588..40b63d4d31d4 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -742,8 +742,8 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) - Note that since zero values will probably be transformed to non-zero values, output of the - transformer will be DenseVector even for sparse input. + .. note:: Since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) @@ -1014,9 +1014,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, :py:attr:`dropLast`) because it makes the vector entries sum up to one, and hence linearly dependent. 
So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - Note that this is different from scikit-learn's OneHotEncoder, - which keeps all categories. - The output vectors are sparse. + + .. note:: This is different from scikit-learn's OneHotEncoder, + which keeps all categories. The output vectors are sparse. .. seealso:: @@ -1698,7 +1698,8 @@ def getLabels(self): class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A feature transformer that filters out stop words from input. - Note: null values from input array are preserved unless adding null to stopWords explicitly. + + .. note:: null values from input array are preserved unless adding null to stopWords explicitly. >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"]) >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"]) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index a5df727fdb41..1705c156ce4c 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -746,11 +746,12 @@ def __hash__(self): class Vectors(object): """ - Factory methods for working with vectors. Note that dense vectors - are simply represented as NumPy array objects, so there is no need - to covert them for use in MLlib. For sparse vectors, the factory - methods in this class create an MLlib-compatible type, or users - can pass in SciPy's C{scipy.sparse} column vectors. + Factory methods for working with vectors. + + .. note:: Dense vectors are simply represented as NumPy array objects, + so there is no need to covert them for use in MLlib. For sparse vectors, + the factory methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. """ @staticmethod diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 385391ba53fd..b42e80706980 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -245,9 +245,9 @@ def explainedVariance(self): .. seealso:: `Wikipedia explain variation \ `_ - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("explainedVariance") @@ -259,9 +259,9 @@ def meanAbsoluteError(self): corresponding to the expected value of the absolute error loss or l1-norm loss. - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("meanAbsoluteError") @@ -273,9 +273,9 @@ def meanSquaredError(self): corresponding to the expected value of the squared error loss or quadratic loss. - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("meanSquaredError") @@ -286,9 +286,9 @@ def rootMeanSquaredError(self): Returns the root mean squared error, which is defined as the square root of the mean squared error. 
- Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("rootMeanSquaredError") @@ -301,9 +301,9 @@ def r2(self): .. seealso:: `Wikipedia coefficient of determination \ ` - Note: This ignores instance weights (setting all to 1.0) from - `LinearRegression.weightCol`. This will change in later Spark - versions. + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. """ return self._call_java("r2") @@ -822,7 +822,7 @@ def featureImportances(self): where gain is scaled by the number of instances passing through node - Normalize importances for tree to sum to 1. - Note: Feature importance for single decision trees can have high variance due to + .. note:: Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a :py:class:`RandomForestRegressor` to determine feature importance instead. """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 2036168e456f..91123ace3387 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -699,9 +699,9 @@ class StreamingKMeansModel(KMeansModel): * n_t+1: New number of weights. * a: Decay Factor, which gives the forgetfulness. - Note that if a is set to 1, it is the weighted mean of the previous - and new data. If it set to zero, the old centroids are completely - forgotten. + .. note:: If a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. :param clusterCenters: Initial cluster centers. diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7eaa2282cb8b..bde0f67be775 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -114,9 +114,9 @@ def transform(self, vector): """ Applies transformation on a vector or an RDD[Vector]. - Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param vector: Vector or RDD of Vector to be transformed. """ @@ -139,9 +139,9 @@ def transform(self, vector): """ Applies standardization transformation on a vector. - Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is @@ -407,7 +407,7 @@ class HashingTF(object): Maps a sequence of terms to their term frequencies using the hashing trick. - Note: the terms must be hashable (can not be dict/set/list...). + .. note:: The terms must be hashable (can not be dict/set/list...). :param numFeatures: number of features (default: 2^20) @@ -469,9 +469,9 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. 
- Note: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. note:: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency vector @@ -551,7 +551,7 @@ def transform(self, word): """ Transforms a word to its vector representation - Note: local use only + .. note:: Local use only :param word: a word :return: vector representation of word(s) @@ -570,7 +570,7 @@ def findSynonyms(self, word, num): :param num: number of synonyms to find :return: array of (word, cosineSimilarity) - Note: local use only + .. note:: Local use only """ if not isinstance(word, basestring): word = _convert_to_vector(word) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index d37e715c8d8e..031f22c02098 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -835,11 +835,12 @@ def __hash__(self): class Vectors(object): """ - Factory methods for working with vectors. Note that dense vectors - are simply represented as NumPy array objects, so there is no need - to covert them for use in MLlib. For sparse vectors, the factory - methods in this class create an MLlib-compatible type, or users - can pass in SciPy's C{scipy.sparse} column vectors. + Factory methods for working with vectors. + + .. note:: Dense vectors are simply represented as NumPy array objects, + so there is no need to covert them for use in MLlib. For sparse vectors, + the factory methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. """ @staticmethod diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 538cada7d163..600655c912ca 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -171,8 +171,9 @@ def computeColumnSummaryStatistics(self): def computeCovariance(self): """ Computes the covariance matrix, treating each row as an - observation. Note that this cannot be computed on matrices - with more than 65535 columns. + observation. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([[1, 2], [2, 1]]) >>> mat = RowMatrix(rows) @@ -185,8 +186,9 @@ def computeCovariance(self): @since('2.0.0') def computeGramianMatrix(self): """ - Computes the Gramian matrix `A^T A`. Note that this cannot be - computed on matrices with more than 65535 columns. + Computes the Gramian matrix `A^T A`. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) >>> mat = RowMatrix(rows) @@ -458,8 +460,9 @@ def columnSimilarities(self): @since('2.0.0') def computeGramianMatrix(self): """ - Computes the Gramian matrix `A^T A`. Note that this cannot be - computed on matrices with more than 65535 columns. + Computes the Gramian matrix `A^T A`. + + .. note:: This cannot be computed on matrices with more than 65535 columns. >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... 
IndexedRow(1, [4, 5, 6])]) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 705022934e41..1b66f5b51044 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -44,7 +44,7 @@ class LabeledPoint(object): Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). - Note: 'label' and 'features' are accessible as class attributes. + .. note:: 'label' and 'features' are accessible as class attributes. .. versionadded:: 1.0.0 """ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 67d5f0e44f41..49b26446dbc3 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -164,7 +164,6 @@ def chiSqTest(observed, expected=None): of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category having an expected frequency of `1 / len(observed)`. - (Note: `observed` cannot contain negative values) If `observed` is matrix, conduct Pearson's independence test on the input contingency matrix, which cannot contain negative entries or @@ -176,6 +175,8 @@ def chiSqTest(observed, expected=None): contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. + .. note:: `observed` cannot contain negative values + :param observed: it could be a vector containing the observed categorical counts/relative frequencies, or the contingency matrix (containing either counts or relative frequencies), diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index b3011d42e56a..a6089fc8b9d3 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -40,9 +40,9 @@ def predict(self, x): Predict values for a single data point or an RDD of points using the model trained. - Note: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. note:: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -85,9 +85,9 @@ def predict(self, x): """ Predict the label of one or more examples. - Note: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. note:: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. :param x: Data point (feature vector), or an RDD of data points (feature diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f21a364df910..9e05da89af08 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -417,10 +417,8 @@ def sample(self, withReplacement, fraction, seed=None): with replacement: expected number of times each element is chosen; fraction must be >= 0 :param seed: seed for the random number generator - .. note:: - - This is not guaranteed to provide exactly the fraction specified of the total count - of the given :class:`DataFrame`. + .. note:: This is not guaranteed to provide exactly the fraction specified of the total + count of the given :class:`DataFrame`. 
>>> rdd = sc.parallelize(range(100), 4) >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 @@ -460,8 +458,8 @@ def takeSample(self, withReplacement, num, seed=None): """ Return a fixed-size sampled subset of this RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) @@ -572,7 +570,7 @@ def intersection(self, other): Return the intersection of this RDD and another one. The output will not contain any duplicate elements, even if the input RDDs did. - Note that this method performs a shuffle internally. + .. note:: This method performs a shuffle internally. >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8]) @@ -803,8 +801,9 @@ def func(it): def collect(self): """ Return a list that contains all of the elements in this RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) @@ -1251,10 +1250,10 @@ def top(self, num, key=None): """ Get the top N elements from an RDD. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. - Note: It returns the list sorted in descending order. + .. note:: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] @@ -1276,8 +1275,8 @@ def takeOrdered(self, num, key=None): Get the N elements from an RDD ordered in ascending order or as specified by the optional key function. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -1298,11 +1297,11 @@ def take(self, num): that partition to estimate the number of additional partitions needed to satisfy the limit. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. - Translated from the Scala implementation in RDD#take(). + .. note:: this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) @@ -1366,8 +1365,9 @@ def first(self): def isEmpty(self): """ - Returns true if and only if the RDD contains no elements at all. Note that an RDD - may be empty even when it has at least 1 partition. + Returns true if and only if the RDD contains no elements at all. + + .. note:: an RDD may be empty even when it has at least 1 partition. 
>>> sc.parallelize([]).isEmpty() True @@ -1558,8 +1558,8 @@ def collectAsMap(self): """ Return the key-value pairs in this RDD to the master as a dictionary. - Note that this method should only be used if the resulting data is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: this method should only be used if the resulting data is expected + to be small, as all the data is loaded into the driver's memory. >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() >>> m[1] @@ -1796,8 +1796,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, set of aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined - type" C. Note that V and C can be different -- for example, one might - group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). + type" C. Users provide three functions: @@ -1809,6 +1808,9 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, In addition, users can control the partitioning of the output RDD. + .. note:: V and C can be different -- for example, one might group an RDD of type + (Int, Int) into an RDD of type (Int, List[Int]). + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) @@ -1880,9 +1882,9 @@ def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. - Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey or aggregateByKey will - provide much better performance. + .. note:: If you are grouping in order to perform an aggregation (such as a + sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.groupByKey().mapValues(len).collect()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6fe622643291..b9d90384e3e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -457,7 +457,7 @@ def foreachPartition(self, f): def cache(self): """Persists the :class:`DataFrame` with the default storage level (C{MEMORY_AND_DISK}). - .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. + .. note:: The default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True self._jdf.cache() @@ -470,7 +470,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK): a new storage level if the :class:`DataFrame` does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_AND_DISK}). - .. note:: the default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. + .. note:: The default storage level has changed to C{MEMORY_AND_DISK} to match Scala in 2.0. """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) @@ -597,10 +597,8 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - .. note:: - - This is not guaranteed to provide exactly the fraction specified of the total count - of the given :class:`DataFrame`. + .. note:: This is not guaranteed to provide exactly the fraction specified of the total + count of the given :class:`DataFrame`. 
>>> df.sample(False, 0.5, 42).count() 2 @@ -866,8 +864,8 @@ def describe(self, *cols): This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns. - .. note:: This function is meant for exploratory data analysis, as we make no \ - guarantee about the backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no + guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe(['age']).show() +-------+------------------+ @@ -900,8 +898,8 @@ def describe(self, *cols): def head(self, n=None): """Returns the first ``n`` rows. - Note that this method should only be used if the resulting array is expected - to be small, as all the data is loaded into the driver's memory. + .. note:: This method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. @@ -1462,8 +1460,8 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. - .. note:: This function is meant for exploratory data analysis, as we make no \ - guarantee about the backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no + guarantee about the backward compatibility of the schema of the resulting DataFrame. :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. @@ -1564,11 +1562,11 @@ def toDF(self, *cols): def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. - Note that this method should only be used if the resulting Pandas's DataFrame is expected - to be small, as all the data is loaded into the driver's memory. - This is only available if Pandas is installed and available. + .. note:: This method should only be used if the resulting Pandas's DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 46a092f16d4f..d8abafcde384 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -359,7 +359,7 @@ def grouping_id(*cols): (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) - .. note:: the list of columns should match with grouping columns exactly, or empty (means all + .. note:: The list of columns should match with grouping columns exactly, or empty (means all the grouping columns). >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() @@ -547,7 +547,7 @@ def shiftRightUnsigned(col, numBits): def spark_partition_id(): """A column for partition ID. - Note that this is indeterministic because it depends on data partitioning and task scheduling. + .. note:: This is indeterministic because it depends on data partitioning and task scheduling. 
>>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] @@ -1852,9 +1852,10 @@ def __call__(self, *cols): @since(1.3) def udf(f, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). - Note that the user-defined functions must be deterministic. Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. + + .. note:: The user-defined functions must be deterministic. Due to optimization, + duplicate invocations may be eliminated or the function may even be invoked more times than + it is present in the query. :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 0e4589be976e..9c3a237699f9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -90,10 +90,12 @@ def awaitTermination(self, timeout=None): @since(2.0) def processAllAvailable(self): """Blocks until all available data in the source has been processed and committed to the - sink. This method is intended for testing. Note that in the case of continually arriving - data, this method may block forever. Additionally, this method is only guaranteed to block - until data that has been synchronously appended data to a stream source prior to invocation. - (i.e. `getOffset` must immediately reflect the addition). + sink. This method is intended for testing. + + .. note:: In the case of continually arriving data, this method may block forever. + Additionally, this method is only guaranteed to block until data that has been + synchronously appended data to a stream source prior to invocation. + (i.e. `getOffset` must immediately reflect the addition). """ return self._jsq.processAllAvailable() diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ec3ad9933cf6..17c34f8a1c54 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -304,7 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): Create an input stream from an queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. - NOTE: changes to the queue after the stream is created will not be recognized. + .. note:: Changes to the queue after the stream is created will not be recognized. @param rdds: Queue of RDDs @param oneAtATime: pick one rdd each time or pick all of them once. diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 434ce83e1e6f..3a8d8b819fd3 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -42,8 +42,8 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, Create an input stream that pulls messages from a Kinesis stream. This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is - enabled. Make sure that your checkpoint directory is secure. + .. note:: The given AWS credentials will get saved in DStream checkpoints if checkpointing + is enabled. Make sure that your checkpoint directory is secure. 
:param ssc: StreamingContext object :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to From bb152cdfbb8d02130c71d2326ae81939725c2cf0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 09:16:20 -0800 Subject: [PATCH 283/381] [SPARK-18519][SQL] map type can not be used in EqualTo ## What changes were proposed in this pull request? Technically map type is not orderable, but can be used in equality comparison. However, due to the limitation of the current implementation, map type can't be used in equality comparison so that it can't be join key or grouping key. This PR makes this limitation explicit, to avoid wrong result. ## How was this patch tested? updated tests. Author: Wenchen Fan Closes #15956 from cloud-fan/map-type. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 ------- .../sql/catalyst/expressions/predicates.scala | 30 +++++++++++++ .../analysis/AnalysisErrorSuite.scala | 44 +++++++------------ .../ExpressionTypeCheckingSuite.scala | 2 + 4 files changed, 48 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98e50d0d3c67..80e577e5c4c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -183,21 +183,6 @@ trait CheckAnalysis extends PredicateHelper { s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) => - def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { - case p: Predicate => - p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) - case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.sql} cannot be used " + - "in join conditions") - case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.sql} cannot be used " + - "in join conditions") - case _ => // OK - } - - checkValidJoinConditionExprs(condition) - case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case aggExpr: AggregateExpression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7946c201f4ff..2ad452b6a90c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -412,6 +412,21 @@ case class EqualTo(left: Expression, right: Expression) override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. 
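+        // In practice this means a map-typed column cannot be used as a join key or a
+        // grouping key; such queries now fail analysis with "Cannot use map type in
+        // EqualTo" rather than silently producing wrong results.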
+ if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -440,6 +455,21 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. + if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + override def symbol: String = "<=>" override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 21afe9fec594..8c1faea2394c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -465,34 +465,22 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } - test("Join can't work on binary and map types") { - val plan = - Join( - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", BinaryType)(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - - assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) - - val plan2 = - Join( - LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - - assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) + test("Join can work on binary types but can't work on map types") { + val left = LocalRelation('a.binary, 'b.map(StringType, StringType)) + val right = LocalRelation('c.binary, 'd.map(StringType, StringType)) + + val plan1 = left.join( + right, + joinType = Cross, + condition = Some('a === 'c)) + + assertAnalysisSuccess(plan1) + + val plan2 = left.join( + right, + joinType = Cross, + condition = Some('b === 'd)) + assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 542e654bbce1..744057b7c5f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -111,6 +111,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo") + assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe") assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(LessThanOrEqual('mapField, 'mapField), From 45ea46b7b397f023b4da878eb11e21b08d931115 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 22 Nov 2016 12:06:21 -0800 Subject: [PATCH 284/381] [SPARK-18504][SQL] Scalar subquery with extra group by columns returning incorrect result ## What changes were proposed in this pull request? This PR blocks an incorrect result scenario in scalar subquery where there are GROUP BY column(s) that are not part of the correlated predicate(s). Example: // Incorrect result Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq((1,1),(1,2)).toDF("c1","c2").createOrReplaceTempView("t2") sql("select (select sum(-1) from t2 where t1.c1=t2.c1 group by t2.c2) from t1").show // How can selecting a scalar subquery from a 1-row table return 2 rows? ## How was this patch tested? sql/test, catalyst/test new test case covering the reported problem is added to SubquerySuite.scala Author: Nattavut Sutyanyong Closes #15936 from nsyca/scalarSubqueryIncorrect-1. 
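A sketch (not part of the patch) contrasting a permitted correlated subquery with the one now blocked; it assumes an active SparkSession `spark` and reuses the toy tables from the description above:

```scala
// Sketch of the SPARK-18504 rule; assumes an active SparkSession named `spark`.
import spark.implicits._

Seq(1).toDF("c1").createOrReplaceTempView("t1")
Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")

// Permitted: the GROUP BY column t2.c1 also appears in the correlated predicate
// t1.c1 = t2.c1, so the subquery yields at most one row per outer row.
spark.sql("SELECT (SELECT sum(-1) FROM t2 WHERE t1.c1 = t2.c1 GROUP BY t2.c1) FROM t1").show()

// Blocked by this change: t2.c2 is not a correlated column, so the subquery could yield
// several rows per outer row; the analyzer now raises an AnalysisException instead.
spark.sql("SELECT (SELECT sum(-1) FROM t2 WHERE t1.c1 = t2.c1 GROUP BY t2.c2) FROM t1").show()
```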
--- .../sql/catalyst/analysis/Analyzer.scala | 3 -- .../sql/catalyst/analysis/CheckAnalysis.scala | 30 +++++++++++++++---- .../org/apache/spark/sql/SubquerySuite.scala | 12 ++++++++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ec5f710fd987..0155741ddbc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1241,9 +1241,6 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { - case s @ ScalarSubquery(sub, conditions, exprId) - if sub.resolved && conditions.isEmpty && sub.output.size != 1 => - failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 80e577e5c4c7..26d26385904f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + case s @ ScalarSubquery(query, conditions, _) + // If no correlation, the output must be exactly one column + if (conditions.isEmpty && query.output.size != 1) => + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates which contain exactly one aggregate expressions. - // The analyzer has already checked that subquery contained only one output column, and - // added all the grouping expressions to the aggregate. - def checkAggregate(a: Aggregate): Unit = { - val aggregates = a.expressions.flatMap(_.collect { + def checkAggregate(agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, + // and added all the grouping expressions to the aggregate. 
+ val aggregates = agg.expressions.flatMap(_.collect { case a: AggregateExpression => a }) if (aggregates.isEmpty) { failAnalysis("The output of a correlated scalar subquery must be aggregated") } + + // SPARK-18504: block cases where GROUP BY columns + // are not part of the correlated columns + val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references)) + val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references)) + val invalidCols = groupByCols.diff(predicateCols) + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "a GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c84a6f161893..f1dd1c620e66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil) } + test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { + withTempView("t") { + Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + + val errMsg = intercept[AnalysisException] { + sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1") + } + assert(errMsg.getMessage.contains( + "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + } + } + test("non-aggregated correlated scalar subquery") { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") From 702cd403fc8e5ce8281fe8828197ead46bdb8832 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 15:25:22 -0500 Subject: [PATCH 285/381] [SPARK-18507][SQL] HiveExternalCatalog.listPartitions should only call getTable once ## What changes were proposed in this pull request? HiveExternalCatalog.listPartitions should only call `getTable` once, instead of calling it for every partitions. ## How was this patch tested? N/A Author: Wenchen Fan Closes #15978 from cloud-fan/perf. 
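The fix is a one-line hoist, but the pattern is worth calling out: a loop-invariant metastore lookup moves out of the per-partition closure so only one round trip is made. A minimal sketch of that pattern, not the actual Spark code, with `fetchTable` as a made-up stand-in for the expensive `getTable(db, table)` call:

```scala
// Hypothetical sketch of the hoisting pattern; fetchTable stands in for an expensive
// metastore round trip such as getTable(db, table).partitionColumnNames.
def fetchTable(db: String, table: String): Seq[String] = Seq("year", "month")

def restoreSpecs(
    db: String,
    table: String,
    rawSpecs: Seq[Map[String, String]]): Seq[Map[String, String]] = {
  // Look the table up once, not once per partition.
  val partColNames = fetchTable(db, table)
  rawSpecs.map { spec =>
    partColNames.flatMap(name => spec.get(name).map(name -> _)).toMap
  }
}
```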
--- .../scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 5dbb4024bbee..ff0923f04893 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -907,8 +907,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat db: String, table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { + val actualPartColNames = getTable(db, table).partitionColumnNames client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + part.copy(spec = restorePartitionSpec(part.spec, actualPartColNames)) } } From bdc8153e8689262708c7fade5c065bd7fc8a84fc Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 22 Nov 2016 13:03:50 -0800 Subject: [PATCH 286/381] [SPARK-18465] Add 'IF EXISTS' clause to 'UNCACHE' to not throw exceptions when table doesn't exist ## What changes were proposed in this pull request? While this behavior is debatable, consider the following use case: ```sql UNCACHE TABLE foo; CACHE TABLE foo AS SELECT * FROM bar ``` The command above fails the first time you run it. But I want to run the command above over and over again, and I don't want to change my code just for the first run of it. The issue is that subsequent `CACHE TABLE` commands do not overwrite the existing table. Now we can do: ```sql UNCACHE TABLE IF EXISTS foo; CACHE TABLE foo AS SELECT * FROM bar ``` ## How was this patch tested? Unit tests Author: Burak Yavuz Closes #15896 from brkyvz/uncache. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- .../apache/spark/sql/execution/command/cache.scala | 12 ++++++++++-- .../org/apache/spark/sql/hive/CachedTableSuite.scala | 5 ++++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index fcca11c69f0a..bd05855f0a19 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -142,7 +142,7 @@ statement | REFRESH TABLE tableIdentifier #refreshTable | REFRESH .*? #refreshResource | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable - | UNCACHE TABLE tableIdentifier #uncacheTable + | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable | CLEAR CACHE #clearCache | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE tableIdentifier partitionSpec? #loadData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 112d812cb6c7..df509a56792e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -233,7 +233,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create an [[UncacheTableCommand]] logical plan. 
*/ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index c31f4dc9aba4..336f14dd97ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -49,10 +50,17 @@ case class CacheTableCommand( } -case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableCommand { +case class UncacheTableCommand( + tableIdent: TableIdentifier, + ifExists: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.catalog.uncacheTable(tableIdent.quotedString) + val tableId = tableIdent.quotedString + try { + sparkSession.catalog.uncacheTable(tableId) + } catch { + case _: NoSuchTableException if ifExists => // don't throw + } Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index fc35304c80ec..3871b3d78588 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -101,13 +101,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("DROP TABLE IF EXISTS nonexistantTable") } - test("correct error on uncache of nonexistant tables") { + test("uncache of nonexistant tables") { + // make sure table doesn't exist + intercept[NoSuchTableException](spark.table("nonexistantTable")) intercept[NoSuchTableException] { spark.catalog.uncacheTable("nonexistantTable") } intercept[NoSuchTableException] { sql("UNCACHE TABLE nonexistantTable") } + sql("UNCACHE TABLE IF EXISTS nonexistantTable") } test("no error on uncache of non-cached table") { From 016bc62859669e93e8190a1b52b01d046282cc22 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 22 Nov 2016 12:20:45 -0800 Subject: [PATCH 287/381] refine test --- .../scala/org/apache/spark/sql/Dataset.scala | 7 ++- .../spark/sql/DatasetToArrowSuite.scala | 53 ++++++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cb3d11857fca..3bb3a3eb551b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2293,7 +2293,7 @@ class Dataset[T] private[sql]( case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) case StringType => - ArrowType.Utf8.INSTANCE + ArrowType.List.INSTANCE case DoubleType => new ArrowType.FloatingPoint(Precision.DOUBLE) case FloatType => @@ -2354,7 +2354,10 @@ class Dataset[T] private[sql]( case IntegerType => rows.foreach { row => buf.writeInt(row.getInt(idx)) } 
case StringType => - rows.foreach { row => buf.writeByte(row.getByte(idx)) } + // TODO: Transform String type + rows.foreach { row => + buf.writeBytes(row.getString(idx).getBytes()) + } case DoubleType => rows.foreach { row => buf.writeDouble(row.getDouble(idx)) } case FloatType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala index c4271800f300..bcdbbd5b2705 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala @@ -25,16 +25,17 @@ import java.nio.channels.FileChannel import scala.util.Random import io.netty.buffer.ArrowBuf +import org.apache.arrow.flatbuf.Precision import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.file.ArrowReader import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -case class ArrowIntTest(a: Int, b: Int) -case class ArrowIntDoubleTest(a: Int, b: Double) +case class ArrowTestClass(a: Int, b: Double, c: String) class DatasetToArrowSuite extends QueryTest with SharedSQLContext { @@ -44,12 +45,15 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { @transient var dataset: Dataset[_] = _ @transient var column1: Seq[Int] = _ @transient var column2: Seq[Double] = _ + @transient var column3: Seq[String] = _ override def beforeAll(): Unit = { super.beforeAll() column1 = Seq.fill(numElements)(Random.nextInt) column2 = Seq.fill(numElements)(Random.nextDouble) - dataset = column1.zip(column2).map{ case (c1, c2) => ArrowIntDoubleTest(c1, c2) }.toDS() + column3 = Seq.fill(numElements)(Random.nextString(Random.nextInt(100))) + dataset = column1.zip(column2).zip(column3) + .map{ case ((c1, c2), c3) => ArrowTestClass(c1, c2, c3) }.toDS() } test("Collect as arrow to python") { @@ -63,32 +67,50 @@ class DatasetToArrowSuite extends QueryTest with SharedSQLContext { val footer = reader.readFooter() val schema = footer.getSchema - assert(schema.getFields.size() === dataset.schema.fields.length) - assert(schema.getFields.get(0).getName === dataset.schema.fields(0).name) - assert(schema.getFields.get(0).isNullable === dataset.schema.fields(0).nullable) - assert(schema.getFields.get(0).getType.isInstanceOf[ArrowType.Int]) - assert(schema.getFields.get(1).getName === dataset.schema.fields(1).name) - assert(schema.getFields.get(1).isNullable === dataset.schema.fields(1).nullable) - assert(schema.getFields.get(1).getType.isInstanceOf[ArrowType.FloatingPoint]) + val numCols = schema.getFields.size() + assert(numCols === dataset.schema.fields.length) + for (i <- 0 to schema.getFields.size()) { + val arrowField = schema.getFields.get(i) + val sparkField = dataset.schema.fields(i) + assert(arrowField.getName === sparkField.name) + assert(arrowField.isNullable === sparkField.nullable) + assert(DatasetToArrowSuite.compareSchemaTypes(arrowField.getType, sparkField.dataType)) + } val blockMetadata = footer.getRecordBatches assert(blockMetadata.size() === 1) val recordBatch = reader.readRecordBatch(blockMetadata.get(0)) val nodes = recordBatch.getNodes - assert(nodes.size() === 2) + assert(nodes.size() === numCols) val firstNode = nodes.get(0) assert(firstNode.getLength === numElements) assert(firstNode.getNullCount === 0) val buffers = recordBatch.getBuffers - assert(buffers.size() === 4) + assert(buffers.size() === numCols * 2) 
val column1Read = receiver.getIntArray(buffers.get(1)) assert(column1Read === column1) val column2Read = receiver.getDoubleArray(buffers.get(3)) assert(column2Read === column2) + // TODO: Check column 3 is right + } +} + +object DatasetToArrowSuite { + def compareSchemaTypes(at: ArrowType, dt: DataType): Boolean = { + (at, dt) match { + case (_: ArrowType.Int, _: IntegerType) => true + case (_: ArrowType.FloatingPoint, _: DoubleType) => + at.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE + case (_: ArrowType.FloatingPoint, _: FloatType) => + at.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE + case (_: ArrowType.Utf8, _: StringType) => true + case (_: ArrowType.Bool, _: BooleanType) => true + case _ => false + } } } @@ -110,6 +132,13 @@ class RecordBatchReceiver { resultArray } + def getStringArray(buf: ArrowBuf): Array[String] = { + val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asCharBuffer() + val resultArray = Array.ofDim[String](buffer.remaining()) + // TODO: Get String Array back + resultArray + } + private def array(buf: ArrowBuf): Array[Byte] = { val bytes = Array.ofDim[Byte](buf.readableBytes()) buf.readBytes(bytes) From 2fd101b2f0028e005fbb0bdd29e59af37aa637da Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Nov 2016 14:15:57 -0800 Subject: [PATCH 288/381] [SPARK-18373][SPARK-18529][SS][KAFKA] Make failOnDataLoss=false work with Spark jobs ## What changes were proposed in this pull request? This PR adds `CachedKafkaConsumer.getAndIgnoreLostData` to handle corner cases of `failOnDataLoss=false`. It also resolves [SPARK-18529](https://issues.apache.org/jira/browse/SPARK-18529) after refactoring codes: Timeout will throw a TimeoutException. ## How was this patch tested? Because I cannot find any way to manually control the Kafka server to clean up logs, it's impossible to write unit tests for each corner case. Therefore, I just created `test("stress test for failOnDataLoss=false")` which should cover most of corner cases. I also modified some existing tests to test for both `failOnDataLoss=false` and `failOnDataLoss=true` to make sure it doesn't break existing logic. Author: Shixiong Zhu Closes #15820 from zsxwing/failOnDataLoss. 
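For orientation, the option this change hardens is set when building the streaming reader. A usage sketch follows; it assumes an active SparkSession `spark`, and the broker address and topic name are placeholders:

```scala
// Usage sketch only; "host:9092" and "events" are placeholder values.
val kafkaStream = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host:9092")
  .option("subscribe", "events")
  .option("startingOffsets", "earliest")
  // "false": offsets that are no longer available (aged out or topic deleted) are logged
  // and skipped so the query keeps running; "true": the query fails on any data loss.
  .option("failOnDataLoss", "false")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
```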
--- .../sql/kafka010/CachedKafkaConsumer.scala | 236 ++++++++++++-- .../spark/sql/kafka010/KafkaSource.scala | 23 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 42 ++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 297 +++++++++++++++--- .../spark/sql/kafka010/KafkaTestUtils.scala | 20 +- 5 files changed, 523 insertions(+), 95 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 3b5a96534f9b..3f438e99185b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -18,12 +18,16 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.concurrent.TimeoutException -import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaSource._ /** @@ -34,10 +38,18 @@ import org.apache.spark.internal.Logging private[kafka010] case class CachedKafkaConsumer private( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { + import CachedKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private val consumer = { + private var consumer = createConsumer + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = UNKNOWN_OFFSET + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) val tps = new ju.ArrayList[TopicPartition]() tps.add(topicPartition) @@ -45,42 +57,193 @@ private[kafka010] case class CachedKafkaConsumer private( c } - /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = -2L - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. 
*/ - def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + require(offset < untilOffset, + s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") - if (offset != nextOffsetInFetchedData) { - logInfo(s"Initial fetch for $topicPartition $offset") - seek(offset) - poll(pollTimeoutMs) + // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is + // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then + // we will move to the next available offset within `[offset, untilOffset)` and retry. + // If `failOnDataLoss` is `true`, the loop body will be executed only once. + var toFetchOffset = offset + while (toFetchOffset != UNKNOWN_OFFSET) { + try { + return fetchData(toFetchOffset, pollTimeoutMs) + } catch { + case e: OffsetOutOfRangeException => + // When there is some error thrown, it's better to use a new consumer to drop all cached + // states in the old consumer. We don't need to worry about the performance because this + // is not a common path. + resetConsumer() + reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e) + toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset) + } } + resetFetchedData() + null + } - if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } - assert(fetchedData.hasNext(), - s"Failed to get records for $groupId $topicPartition $offset " + - s"after polling for $pollTimeoutMs") - var record = fetchedData.next() + /** + * Return the next earliest available offset in [offset, untilOffset). If all offsets in + * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return + * `UNKNOWN_OFFSET`. + */ + private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = { + val (earliestOffset, latestOffset) = getAvailableOffsetRange() + logWarning(s"Some data may be lost. Recovering from the earliest offset: $earliestOffset") + if (offset >= latestOffset || earliestOffset >= untilOffset) { + // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap, + // either + // -------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // earliestOffset latestOffset offset untilOffset + // + // or + // -------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // offset untilOffset earliestOffset latestOffset + val warningMessage = + s""" + |The current available offset range is [$earliestOffset, $latestOffset). + | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be + | skipped ${additionalMessage(failOnDataLoss = false)} + """.stripMargin + logWarning(warningMessage) + UNKNOWN_OFFSET + } else if (offset >= earliestOffset) { + // ----------------------------------------------------------------------------- + // ^ ^ ^ ^ + // | | | | + // earliestOffset offset min(untilOffset,latestOffset) max(untilOffset, latestOffset) + // + // This will happen when a topic is deleted and recreated, and new data are pushed very fast, + // then we will see `offset` disappears first then appears again. Although the parameters + // are same, the state in Kafka cluster is changed, so the outer loop won't be endless. 
+ logWarning(s"Found a disappeared offset $offset. " + + s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}") + offset + } else { + // ------------------------------------------------------------------------------ + // ^ ^ ^ ^ + // | | | | + // offset earliestOffset min(untilOffset,latestOffset) max(untilOffset, latestOffset) + val warningMessage = + s""" + |The current available offset range is [$earliestOffset, $latestOffset). + | Offset ${offset} is out of range, and records in [$offset, $earliestOffset) will be + | skipped ${additionalMessage(failOnDataLoss = false)} + """.stripMargin + logWarning(warningMessage) + earliestOffset + } + } - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topicPartition $offset") + /** + * Get the record at `offset`. + * + * @throws OffsetOutOfRangeException if `offset` is out of range + * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. + */ + private def fetchData( + offset: Long, + pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { + // This is the first fetch, or the last pre-fetched data has been drained. + // Seek to the offset because we may call seekToBeginning or seekToEnd before this. seek(offset) poll(pollTimeoutMs) - assert(fetchedData.hasNext(), - s"Failed to get records for $groupId $topicPartition $offset " + - s"after polling for $pollTimeoutMs") - record = fetchedData.next() + } + + if (!fetchedData.hasNext()) { + // We cannot fetch anything after `poll`. Two possible cases: + // - `offset` is out of range so that Kafka returns nothing. Just throw + // `OffsetOutOfRangeException` to let the caller handle it. + // - Cannot fetch any data before timeout. TimeoutException will be thrown. + val (earliestOffset, latestOffset) = getAvailableOffsetRange() + if (offset < earliestOffset || offset >= latestOffset) { + throw new OffsetOutOfRangeException( + Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + } else { + throw new TimeoutException( + s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + } + } else { + val record = fetchedData.next() + nextOffsetInFetchedData = record.offset + 1 + // `seek` is always called before "poll". So "record.offset" must be same as "offset". assert(record.offset == offset, - s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + s"The fetched data has a different offset: expected $offset but was ${record.offset}") + record } + } + + /** Create a new consumer and reset cached states */ + private def resetConsumer(): Unit = { + consumer.close() + consumer = createConsumer + resetFetchedData() + } - nextOffsetInFetchedData = offset + 1 - record + /** Reset the internal pre-fetched data. */ + private def resetFetchedData(): Unit = { + nextOffsetInFetchedData = UNKNOWN_OFFSET + fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + } + + /** + * Return an addition message including useful message and instruction. + */ + private def additionalMessage(failOnDataLoss: Boolean): String = { + if (failOnDataLoss) { + s"(GroupId: $groupId, TopicPartition: $topicPartition). " + + s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE" + } else { + s"(GroupId: $groupId, TopicPartition: $topicPartition). " + + s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE" + } + } + + /** + * Throw an exception or log a warning as per `failOnDataLoss`. 
+ */ + private def reportDataLoss( + failOnDataLoss: Boolean, + message: String, + cause: Throwable = null): Unit = { + val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}" + if (failOnDataLoss) { + if (cause != null) { + throw new IllegalStateException(finalMessage) + } else { + throw new IllegalStateException(finalMessage, cause) + } + } else { + if (cause != null) { + logWarning(finalMessage) + } else { + logWarning(finalMessage, cause) + } + } } private def close(): Unit = consumer.close() @@ -96,10 +259,24 @@ private[kafka010] case class CachedKafkaConsumer private( logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") fetchedData = r.iterator } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + private def getAvailableOffsetRange(): (Long, Long) = { + consumer.seekToBeginning(Set(topicPartition).asJava) + val earliestOffset = consumer.position(topicPartition) + consumer.seekToEnd(Set(topicPartition).asJava) + val latestOffset = consumer.position(topicPartition) + (earliestOffset, latestOffset) + } } private[kafka010] object CachedKafkaConsumer extends Logging { + private val UNKNOWN_OFFSET = -2L + private case class CacheKey(groupId: String, topicPartition: TopicPartition) private lazy val cache = { @@ -140,7 +317,10 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { - cache.remove(key) + val removedConsumer = cache.remove(key) + if (removedConsumer != null) { + removedConsumer.close() + } new CachedKafkaConsumer(topicPartition, kafkaParams) } else { if (!cache.containsKey(key)) { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 341081a338c0..1d0d402b82a3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -281,7 +281,7 @@ private[kafka010] case class KafkaSource( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( - sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) } @@ -463,10 +463,9 @@ private[kafka010] case class KafkaSource( */ private def reportDataLoss(message: String): Unit = { if (failOnDataLoss) { - throw new IllegalStateException(message + - ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.") + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") } else { - logWarning(message) + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") } } } @@ -475,6 +474,22 @@ private[kafka010] case class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. 
If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + def kafkaSchema: StructType = StructType(Seq( StructField("key", BinaryType), StructField("value", BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 802dd040aed9..244cd2c225bd 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -28,6 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.NextIterator /** Offset range that one partition of the KafkaSourceRDD has to read */ @@ -61,7 +62,8 @@ private[kafka010] class KafkaSourceRDD( sc: SparkContext, executorKafkaParams: ju.Map[String, Object], offsetRanges: Seq[KafkaSourceRDDOffsetRange], - pollTimeoutMs: Long) + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { override def persist(newLevel: StorageLevel): this.type = { @@ -130,23 +132,31 @@ private[kafka010] class KafkaSourceRDD( logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + s"skipping ${range.topic} ${range.partition}") Iterator.empty - } else { - - val consumer = CachedKafkaConsumer.getOrCreate( - range.topic, range.partition, executorKafkaParams) - var requestOffset = range.fromOffset - - logDebug(s"Creating iterator for $range") - - new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { - override def hasNext(): Boolean = requestOffset < range.untilOffset - override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { - assert(hasNext(), "Can't call next() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeoutMs) - requestOffset += 1 - r + new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { + val consumer = CachedKafkaConsumer.getOrCreate( + range.topic, range.partition, executorKafkaParams) + var requestOffset = range.fromOffset + + override def getNext(): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (requestOffset >= range.untilOffset) { + // Processed all offsets in this partition. + finished = true + null + } else { + val r = consumer.get(requestOffset, range.untilOffset, pollTimeoutMs, failOnDataLoss) + if (r == null) { + // Losing some data. Skip the rest offsets in this partition. 
+ finished = true + null + } else { + requestOffset = r.offset + 1 + r + } + } } + + override protected def close(): Unit = {} } } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 89e713f92df4..cd52fd93d10a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.kafka010 +import java.util.Properties +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -27,8 +31,9 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest } +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.SharedSQLContext abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -202,7 +207,7 @@ class KafkaSourceSuite extends KafkaSourceTest { test("cannot stop Kafka stream") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 5) + testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) val reader = spark @@ -223,52 +228,85 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } - test("assign from latest offsets") { - val topic = newTopic() - testFromLatestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) - } + for (failOnDataLoss <- Seq(true, false)) { + test(s"assign from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromLatestOffsets( + topic, + addPartitions = false, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4)) + } - test("assign from earliest offsets") { - val topic = newTopic() - testFromEarliestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) - } + test(s"assign from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromEarliestOffsets( + topic, + addPartitions = false, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4)) + } - test("assign from specific offsets") { - val topic = newTopic() - testFromSpecificOffsets(topic, "assign" -> assignString(topic, 0 to 4)) - } + test(s"assign from specific offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificOffsets( + topic, + failOnDataLoss = failOnDataLoss, + "assign" -> assignString(topic, 0 to 4), + "failOnDataLoss" -> failOnDataLoss.toString) + } - test("subscribing topic by name from latest offsets") { - val topic = newTopic() - testFromLatestOffsets(topic, true, "subscribe" -> topic) - } + test(s"subscribing topic by name from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromLatestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribe" -> topic) + } - test("subscribing topic by name from earliest offsets") { - val topic = newTopic() - 
testFromEarliestOffsets(topic, true, "subscribe" -> topic) - } + test(s"subscribing topic by name from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromEarliestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribe" -> topic) + } - test("subscribing topic by name from specific offsets") { - val topic = newTopic() - testFromSpecificOffsets(topic, "subscribe" -> topic) - } + test(s"subscribing topic by name from specific offsets (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificOffsets(topic, failOnDataLoss = failOnDataLoss, "subscribe" -> topic) + } - test("subscribing topic by pattern from latest offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromLatestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") - } + test(s"subscribing topic by pattern from latest offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromLatestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } - test("subscribing topic by pattern from earliest offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromEarliestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") - } + test(s"subscribing topic by pattern from earliest offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromEarliestOffsets( + topic, + addPartitions = true, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } - test("subscribing topic by pattern from specific offsets") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-suffix" - testFromSpecificOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + test(s"subscribing topic by pattern from specific offsets (failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromSpecificOffsets( + topic, + failOnDataLoss = failOnDataLoss, + "subscribePattern" -> s"$topicPrefix-.*") + } } test("subscribing topic by pattern with topic deletions") { @@ -413,13 +451,59 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. 
+ val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true + } + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) + } + + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } - private def testFromSpecificOffsets(topic: String, options: (String, String)*): Unit = { + private def testFromSpecificOffsets( + topic: String, + failOnDataLoss: Boolean, + options: (String, String)*): Unit = { val partitionOffsets = Map( new TopicPartition(topic, 0) -> -2L, new TopicPartition(topic, 1) -> -1L, @@ -448,6 +532,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", startingOffsets) .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -469,6 +554,7 @@ class KafkaSourceSuite extends KafkaSourceTest { private def testFromLatestOffsets( topic: String, addPartitions: Boolean, + failOnDataLoss: Boolean, options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, Array("-1")) @@ -480,6 +566,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", s"latest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -513,6 +600,7 @@ class KafkaSourceSuite extends KafkaSourceTest { private def testFromEarliestOffsets( topic: String, addPartitions: Boolean, + failOnDataLoss: Boolean, options: (String, String)*): Unit = { testUtils.createTopic(topic, partitions = 5) testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) @@ -524,6 +612,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option("startingOffsets", s"earliest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) options.foreach { case (k, v) => reader.option(k, v) } val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -552,6 +641,11 @@ class KafkaSourceSuite extends KafkaSourceTest { } } +object KafkaSourceSuite { + @volatile var globalTestUtils: KafkaTestUtils = _ + val collectedData = new ConcurrentLinkedQueue[Any]() +} + class KafkaSourceStressSuite extends KafkaSourceTest { @@ -615,7 +709,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest { } }) case 2 => // Add new partitions - AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn", + AddKafkaData(topics.toSet, d: _*)(message = "Add partition", topicAction = (topic, partition) => { testUtils.addPartitions(topic, partition.get + nextInt(1, 6)) }) @@ -626,3 +720,122 @@ class 
KafkaSourceStressSuite extends KafkaSourceTest { iterations = 50) } } + +class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + private var testUtils: KafkaTestUtils = _ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils { + override def brokerConfiguration: Properties = { + val props = super.brokerConfiguration + // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code + // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at + // least 30 seconds. + props.put("log.cleaner.backoff.ms", "100") + props.put("log.segment.bytes", "40") + props.put("log.retention.bytes", "40") + props.put("log.retention.check.interval.ms", "100") + props.put("delete.retention.ms", "10") + props.put("log.flush.scheduler.interval.ms", "10") + props + } + } + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() + + val testTime = 1.minutes + val startTime = System.currentTimeMillis() + // Track the current existing topics + val topics = mutable.ArrayBuffer[String]() + // Track topics that have been deleted + val deletedTopics = mutable.Set[String]() + while (System.currentTimeMillis() - testTime.toMillis < startTime) { + Random.nextInt(10) match { + case 0 => // Create a new topic + val topic = newTopic() + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. + testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 1 if topics.nonEmpty => // Delete an existing topic + val topic = topics.remove(Random.nextInt(topics.size)) + testUtils.deleteTopic(topic) + logInfo(s"Delete topic $topic") + deletedTopics += topic + case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. + val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) + deletedTopics -= topic + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. 
+ testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 3 => + Thread.sleep(1000) + case _ => // Push random messages + for (topic <- topics) { + val size = Random.nextInt(10) + for (_ <- 0 until size) { + testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) + } + } + } + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } + + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 9b24ccdd560e..f43917e151c5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -155,8 +155,16 @@ class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, overwrite: Boolean = false): Unit = { + var created = false + while (!created) { + try { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + created = true + } catch { + case e: kafka.common.TopicExistsException if overwrite => deleteTopic(topic) + } + } // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) @@ -244,7 +252,7 @@ class KafkaTestUtils extends Logging { offsets } - private def brokerConfiguration: Properties = { + protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") @@ -302,9 +310,11 @@ class KafkaTestUtils extends Logging { } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) }) - deletePath && topicPath && replicaManager && logManager && cleaner + // ensure the topic is gone + val deleted = !zkUtils.getAllTopics().contains(topic) + deletePath && topicPath && replicaManager && logManager && cleaner && deleted } - eventually(timeout(10.seconds)) { + eventually(timeout(60.seconds)) { assert(isDeleted, s"$topic not deleted after timeout") } } From 9c42d4a76ca8046fcca2e20067f2aa461977e65a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Nov 2016 15:10:49 -0800 Subject: [PATCH 289/381] [SPARK-16803][SQL] SaveAsTable does not work when target table is a Hive serde table ### What changes were proposed in this pull request? In Spark 2.0, `SaveAsTable` does not work when the target table is a Hive serde table, but Spark 1.6 works. 
**Spark 1.6** ``` Scala scala> sql("create table sample.sample stored as SEQUENCEFILE as select 1 as key, 'abc' as value") res2: org.apache.spark.sql.DataFrame = [] scala> val df = sql("select key, value as value from sample.sample") df: org.apache.spark.sql.DataFrame = [key: int, value: string] scala> df.write.mode("append").saveAsTable("sample.sample") scala> sql("select * from sample.sample").show() +---+-----+ |key|value| +---+-----+ | 1| abc| | 1| abc| +---+-----+ ``` **Spark 2.0** ``` Scala scala> df.write.mode("append").saveAsTable("sample.sample") org.apache.spark.sql.AnalysisException: Saving data in MetastoreRelation sample, sample is not supported.; ``` So far, we do not plan to support it in Spark 2.1 due to the risk. Spark 1.6 works because it internally uses insertInto. But, if we change it back it will break the semantic of saveAsTable (this method uses by-name resolution instead of using by-position resolution used by insertInto). More extra changes are needed to support `hive` as a `format` in DataFrameWriter. Instead, users should use insertInto API. This PR corrects the error messages. Users can understand how to bypass it before we support it in a separate PR. ### How was this patch tested? Test cases are added Author: gatorsmile Closes #15926 from gatorsmile/saveAsTableFix5. --- .../command/createDataSourceTables.scala | 4 ++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 7e16e43f2bb0..add732c1afc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -175,6 +175,10 @@ case class CreateDataSourceTableAsSelectCommand( existingSchema = Some(l.schema) case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => existingSchema = Some(s.metadata.schema) + case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) => + throw new AnalysisException("Saving data in the Hive serde table " + + s"${c.catalogTable.identifier} is not supported yet. Please use the " + + "insertInto() API as an alternative..") case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 4ab1a54edc46..c7cc75fbc8a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -413,6 +413,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("saveAsTable(CTAS) using append and insertInto when the target table is Hive serde") { + val tableName = "tab1" + withTable(tableName) { + sql(s"CREATE TABLE $tableName STORED AS SEQUENCEFILE AS SELECT 1 AS key, 'abc' AS value") + + val df = sql(s"SELECT key, value FROM $tableName") + val e = intercept[AnalysisException] { + df.write.mode(SaveMode.Append).saveAsTable(tableName) + }.getMessage + assert(e.contains("Saving data in the Hive serde table `default`.`tab1` is not supported " + + "yet. 
Please use the insertInto() API as an alternative.")) + + df.write.insertInto(tableName) + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Row(1, "abc") :: Row(1, "abc") :: Nil + ) + } + } + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { withTable("savedJsonTable") { // Save the df as a managed table (by not specifying the path). From 39a1d30636857715247c82d551b200e1c331ad69 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 22 Nov 2016 15:57:07 -0800 Subject: [PATCH 290/381] [SPARK-18533] Raise correct error upon specification of schema for datasource tables created using CTAS ## What changes were proposed in this pull request? Fixes the inconsistency of error raised between data source and hive serde tables when schema is specified in CTAS scenario. In the process the grammar for create table (datasource) is simplified. **before:** ``` SQL spark-sql> create table t2 (c1 int, c2 int) using parquet as select * from t1; Error in query: mismatched input 'as' expecting {, '.', 'OPTIONS', 'CLUSTERED', 'PARTITIONED'}(line 1, pos 64) == SQL == create table t2 (c1 int, c2 int) using parquet as select * from t1 ----------------------------------------------------------------^^^ ``` **After:** ```SQL spark-sql> create table t2 (c1 int, c2 int) using parquet as select * from t1 > ; Error in query: Operation not allowed: Schema may not be specified in a Create Table As Select (CTAS) statement(line 1, pos 0) == SQL == create table t2 (c1 int, c2 int) using parquet as select * from t1 ^^^ ``` ## How was this patch tested? Added a new test in CreateTableAsSelectSuite Author: Dilip Biswal Closes #15968 from dilipbiswal/ctas. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 +---- .../spark/sql/execution/SparkSqlParser.scala | 24 +++++++++++++++++-- .../sources/CreateTableAsSelectSuite.scala | 9 +++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index bd05855f0a19..4531fe4a0eba 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -71,11 +71,7 @@ statement | createTableHeader ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? #createTableUsing - | createTableHeader tableProvider - (OPTIONS tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? AS? query #createTableUsing + bucketSpec? (AS? query)? #createTableUsing | createTableHeader ('(' columns=colTypeList ')')? (COMMENT STRING)? (PARTITIONED BY '(' partitionColumns=colTypeList ')')? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index df509a56792e..0300bfe1ece3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -322,7 +322,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[CreateTable]] logical plan. + * Create a data source table, returning a [[CreateTable]] logical plan. 
+ * + * Expected format: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * USING table_provider + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [AS select_statement]; + * }}} */ override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) @@ -371,6 +384,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) } + // Don't allow explicit specification of schema for CTAS + if (schema.nonEmpty) { + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + } CreateTable(tableDesc, mode, Some(query)) } else { if (temp) { @@ -1052,7 +1071,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "CTAS statement." operationNotAllowed(errorMessage, ctx) } - // Just use whatever is projected in the select statement as our schema + + // Don't allow explicit specification of schema for CTAS. if (schema.nonEmpty) { operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 5cc9467395ad..61939fe5ef5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -249,4 +249,13 @@ class CreateTableAsSelectSuite } } } + + test("specifying the column list for CTAS") { + withTable("t") { + val e = intercept[ParseException] { + sql("CREATE TABLE t (a int, b int) USING parquet AS SELECT 1, 2") + }.getMessage + assert(e.contains("Schema may not be specified in a Create Table As Select (CTAS)")) + } + } } From d0212eb0f22473ee5482fe98dafc24e16ffcfc63 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Nov 2016 16:49:15 -0800 Subject: [PATCH 291/381] [SPARK-18530][SS][KAFKA] Change Kafka timestamp column type to TimestampType ## What changes were proposed in this pull request? Changed Kafka timestamp column type to TimestampType. ## How was this patch tested? `test("Kafka column types")`. Author: Shixiong Zhu Closes #15969 from zsxwing/SPARK-18530. 
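For illustration only (not part of this patch): a minimal sketch of what the new column type enables. Because the Kafka source now surfaces `timestamp` as `TimestampType` rather than epoch milliseconds, it can be fed straight into event-time operators such as `withWatermark` and `window` without a cast. The broker address and topic name below are hypothetical, and the sketch assumes the `spark-sql-kafka-0-10` connector is on the classpath.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, window}

object KafkaEventTimeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-event-time-sketch").getOrCreate()
    import spark.implicits._

    // Read from Kafka; the projected `timestamp` column is now TimestampType.
    val kafka = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // assumed broker address
      .option("subscribe", "events")                        // assumed topic name
      .load()

    // Event-time windowed count over the Kafka message timestamp, no cast needed.
    val counts = kafka
      .withWatermark("timestamp", "10 seconds")
      .groupBy(window($"timestamp", "5 seconds"))
      .agg(count("*").as("count"))

    counts.writeStream
      .format("console")
      .outputMode("complete")
      .start()
      .awaitTermination()
  }
}
```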
--- .../spark/sql/kafka010/KafkaSource.scala | 16 +++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 81 ++++++++++++++++++- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1d0d402b82a3..d9ab4bb4f873 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,9 +32,12 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.UninterruptibleThread /** @@ -282,7 +285,14 @@ private[kafka010] case class KafkaSource( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => - Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) + InternalRow( + cr.key, + cr.value, + UTF8String.fromString(cr.topic), + cr.partition, + cr.offset, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), + cr.timestampType.id) } logInfo("GetBatch generating RDD of offset range: " + @@ -293,7 +303,7 @@ private[kafka010] case class KafkaSource( currentPartitionOffsets = Some(untilPartitionOffsets) } - sqlContext.createDataFrame(rdd, schema) + sqlContext.internalCreateDataFrame(rdd, schema) } /** Stop this source and free any resources it has allocated. 
*/ @@ -496,7 +506,7 @@ private[kafka010] object KafkaSource { StructField("topic", StringType), StructField("partition", IntegerType), StructField("offset", LongType), - StructField("timestamp", LongType), + StructField("timestamp", TimestampType), StructField("timestampType", IntegerType) )) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cd52fd93d10a..f9f62581a306 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Properties import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random @@ -33,6 +33,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -551,6 +552,84 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("Kafka column types") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val query = kafka + .writeStream + .format("memory") + .outputMode("append") + .queryName("kafkaColumnTypes") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") + assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") + assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") + assert(row.getAs[Int]("partition") === 0, s"Unexpected results: $row") + assert(row.getAs[Long]("offset") === 0L, s"Unexpected results: $row") + // We cannot check the exact timestamp as it's the time that messages were inserted by the + // producer. So here we just use a low bound to make sure the internal conversion works. 
+ assert(row.getAs[java.sql.Timestamp]("timestamp").getTime >= now, s"Unexpected results: $row") + assert(row.getAs[Int]("timestampType") === 0, s"Unexpected results: $row") + query.stop() + } + + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, From 982b82e32e0fc7d30c5d557944a79eb3e6d2da59 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Nov 2016 19:17:48 -0800 Subject: [PATCH 292/381] [SPARK-18501][ML][SPARKR] Fix spark.glm errors when fitting on collinear data ## What changes were proposed in this pull request? * Fix SparkR ```spark.glm``` errors when fitting on collinear data, since ```standard error of coefficients, t value and p value``` are not available in this condition. * Scala/Python GLM summary should throw exception if users get ```standard error of coefficients, t value and p value``` but the underlying WLS was solved by local "l-bfgs". ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #15930 from yanboliang/spark-18501. --- R/pkg/R/mllib.R | 21 ++++++-- R/pkg/inst/tests/testthat/test_mllib.R | 9 ++++ .../GeneralizedLinearRegressionWrapper.scala | 54 +++++++++++-------- .../GeneralizedLinearRegression.scala | 46 +++++++++++++--- .../GeneralizedLinearRegressionSuite.scala | 21 ++++++++ 5 files changed, 115 insertions(+), 36 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 265e64e7466f..02bc6456de4d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -278,8 +278,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients, null/residual deviance, null/residual degrees -#' of freedom, AIC and number of iterations IRLS takes. +#' including at least the coefficients matrix (which includes coefficients, standard error +#' of coefficients, t value and p value), null/residual deviance, null/residual degrees of +#' freedom, AIC and number of iterations IRLS takes. 
If there are collinear columns +#' in you data, the coefficients matrix only provides coefficients. #' #' @rdname spark.glm #' @export @@ -303,9 +305,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), } else { dataFrame(callJMethod(jobj, "rDevianceResiduals")) } - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) + # If the underlying WeightedLeastSquares using "normal" solver, we can provide + # coefficients, standard error of coefficients, t value and p value. Otherwise, + # it will be fitted by local "l-bfgs", we can only provide coefficients. + if (length(features) == length(coefficients)) { + coefficients <- matrix(coefficients, ncol = 1) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + } else { + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + } ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, dispersion = dispersion, null.deviance = null.deviance, deviance = deviance, df.null = df.null, df.residual = df.residual, diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 2a97a51cfa20..467e00cf7919 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -169,6 +169,15 @@ test_that("spark.glm summary", { df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . 
- 1)) + coefs <- unlist(stats$coefficients) + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) }) test_that("spark.glm save/load", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index add4d49110d1..8bcc9fe5d1b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -144,30 +144,38 @@ private[r] object GeneralizedLinearRegressionWrapper features } - val rCoefficientStandardErrors = if (glm.getFitIntercept) { - Array(summary.coefficientStandardErrors.last) ++ - summary.coefficientStandardErrors.dropRight(1) + val rCoefficients: Array[Double] = if (summary.isNormalSolver) { + val rCoefficientStandardErrors = if (glm.getFitIntercept) { + Array(summary.coefficientStandardErrors.last) ++ + summary.coefficientStandardErrors.dropRight(1) + } else { + summary.coefficientStandardErrors + } + + val rTValues = if (glm.getFitIntercept) { + Array(summary.tValues.last) ++ summary.tValues.dropRight(1) + } else { + summary.tValues + } + + val rPValues = if (glm.getFitIntercept) { + Array(summary.pValues.last) ++ summary.pValues.dropRight(1) + } else { + summary.pValues + } + + if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray ++ + rCoefficientStandardErrors ++ rTValues ++ rPValues + } else { + glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues + } } else { - summary.coefficientStandardErrors - } - - val rTValues = if (glm.getFitIntercept) { - Array(summary.tValues.last) ++ summary.tValues.dropRight(1) - } else { - summary.tValues - } - - val rPValues = if (glm.getFitIntercept) { - Array(summary.pValues.last) ++ summary.pValues.dropRight(1) - } else { - summary.pValues - } - - val rCoefficients: Array[Double] = if (glm.getFitIntercept) { - Array(glm.intercept) ++ glm.coefficients.toArray ++ - rCoefficientStandardErrors ++ rTValues ++ rPValues - } else { - glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues + if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray + } else { + glm.coefficients.toArray + } } val rDispersion: Double = summary.dispersion diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3f9de1fe74c9..f33dd0fd294b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1063,45 +1063,75 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( import GeneralizedLinearRegression._ + /** + * Whether the underlying [[WeightedLeastSquares]] using the "normal" solver. + */ + private[ml] val isNormalSolver: Boolean = { + diagInvAtWA.length != 1 || diagInvAtWA(0) != 0 + } + /** * Standard error of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. 
*/ @Since("2.0.0") lazy val coefficientStandardErrors: Array[Double] = { - diagInvAtWA.map(_ * dispersion).map(math.sqrt) + if (isNormalSolver) { + diagInvAtWA.map(_ * dispersion).map(math.sqrt) + } else { + throw new UnsupportedOperationException( + "No Std. Error of coefficients available for this GeneralizedLinearRegressionModel") + } } /** * T-statistic of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val tValues: Array[Double] = { - val estimate = if (model.getFitIntercept) { - Array.concat(model.coefficients.toArray, Array(model.intercept)) + if (isNormalSolver) { + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } else { - model.coefficients.toArray + throw new UnsupportedOperationException( + "No t-statistic available for this GeneralizedLinearRegressionModel") } - estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } /** * Two-sided p-value of estimated coefficients and intercept. + * This value is only available when the underlying [[WeightedLeastSquares]] + * using the "normal" solver. * * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val pValues: Array[Double] = { - if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { - tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + if (isNormalSolver) { + if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + } else { + tValues.map { x => + 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) + } + } } else { - tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + throw new UnsupportedOperationException( + "No p-value available for this GeneralizedLinearRegressionModel") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 9b0fa67630d2..4fab2160339c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1048,6 +1048,27 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } + test("glm handle collinear features") { + val collinearInstances = Seq( + Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), + Instance(2.0, 1.0, Vectors.dense(2.0, 4.0)), + Instance(3.0, 1.0, Vectors.dense(3.0, 6.0)), + Instance(4.0, 1.0, Vectors.dense(4.0, 8.0)) + ).toDF() + val trainer = new GeneralizedLinearRegression() + val model = trainer.fit(collinearInstances) + // to make it clear that underlying WLS did not solve analytically + intercept[UnsupportedOperationException] { + model.summary.coefficientStandardErrors + } + intercept[UnsupportedOperationException] { + model.summary.pValues + } + intercept[UnsupportedOperationException] { + model.summary.tValues + } + } + test("read/write") { def 
checkModelData( model: GeneralizedLinearRegressionModel, From 2559fb4b40c9f42f7b3ed2b77de14461f68b6fa5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 22 Nov 2016 22:25:27 -0800 Subject: [PATCH 293/381] [SPARK-18179][SQL] Throws analysis exception with a proper message for unsupported argument types in reflect/java_method function ## What changes were proposed in this pull request? This PR proposes throwing an `AnalysisException` with a proper message rather than `NoSuchElementException` with the message ` key not found: TimestampType` when unsupported types are given to `reflect` and `java_method` functions. ```scala spark.range(1).selectExpr("reflect('java.lang.String', 'valueOf', cast('1990-01-01' as timestamp))") ``` produces **Before** ``` java.util.NoSuchElementException: key not found: TimestampType at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection$$anonfun$findMethod$1$$anonfun$apply$1.apply(CallMethodViaReflection.scala:159) ... ``` **After** ``` cannot resolve 'reflect('java.lang.String', 'valueOf', CAST('1990-01-01' AS TIMESTAMP))' due to data type mismatch: arguments from the third require boolean, byte, short, integer, long, float, double or string expressions; line 1 pos 0; 'Project [unresolvedalias(reflect(java.lang.String, valueOf, cast(1990-01-01 as timestamp)), Some())] +- Range (0, 1, step=1, splits=Some(2)) ... ``` Added message is, ``` arguments from the third require boolean, byte, short, integer, long, float, double or string expressions ``` ## How was this patch tested? Tests added in `CallMethodViaReflection`. Author: hyukjinkwon Closes #15694 from HyukjinKwon/SPARK-18179. 
--- .../catalyst/expressions/CallMethodViaReflection.scala | 4 ++++ .../expressions/CallMethodViaReflectionSuite.scala | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 40f1b148f928..4859e0c53761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -65,6 +65,10 @@ case class CallMethodViaReflection(children: Seq[Expression]) TypeCheckFailure("first two arguments should be string literals") } else if (!classExists) { TypeCheckFailure(s"class $className not found") + } else if (children.slice(2, children.length) + .exists(e => !CallMethodViaReflection.typeMapping.contains(e.dataType))) { + TypeCheckFailure("arguments from the third require boolean, byte, short, " + + "integer, long, float, double or string expressions") } else if (method == null) { TypeCheckFailure(s"cannot find a static method that matches the argument types in $className") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala index 43367c7e14c3..88d4d460751b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.types.{IntegerType, StringType} @@ -85,6 +87,13 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess) } + test("unsupported type checking") { + val ret = createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("arguments from the third require boolean, byte, short")) + } + test("invoking methods using acceptable types") { checkEvaluation(createExpr(staticClassName, "method1"), "m1") checkEvaluation(createExpr(staticClassName, "method2", 2), "m2") From 7e0cd1d9b168286386f15e9b55988733476ae2bb Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 23 Nov 2016 11:25:47 +0000 Subject: [PATCH 294/381] [SPARK-18073][DOCS][WIP] Migrate wiki to spark.apache.org web site ## What changes were proposed in this pull request? Updates links to the wiki to links to the new location of content on spark.apache.org. ## How was this patch tested? Doc builds Author: Sean Owen Closes #15967 from srowen/SPARK-18073.1. 
--- .github/PULL_REQUEST_TEMPLATE | 2 +- CONTRIBUTING.md | 4 ++-- R/README.md | 2 +- R/pkg/DESCRIPTION | 2 +- README.md | 11 ++++++----- dev/checkstyle.xml | 2 +- docs/_layouts/global.html | 4 ++-- docs/building-spark.md | 4 ++-- docs/contributing-to-spark.md | 2 +- docs/index.md | 4 ++-- docs/sparkr.md | 2 +- docs/streaming-programming-guide.md | 2 +- .../spark/sql/execution/datasources/DataSource.scala | 5 ++--- 13 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 0e41cf182645..5af45d6fa798 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -7,4 +7,4 @@ (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) -Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. +Please review http://spark.apache.org/contributing.html before opening a pull request. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1a8206abe383..8fdd5aa9e7df 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,12 @@ ## Contributing to Spark *Before opening a pull request*, review the -[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? +- Is this a new feature that can stand alone as a [third party project](http://spark.apache.org/third-party-projects.html) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/R/README.md b/R/README.md index 47f9a86dfde1..4c40c5963db7 100644 --- a/R/README.md +++ b/R/README.md @@ -51,7 +51,7 @@ sparkR.session() #### Making changes to SparkR -The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +The [instructions](http://spark.apache.org/contributing.html) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. 
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index fe41a9e7dabb..981ae1246476 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -11,7 +11,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) URL: http://www.apache.org/ http://spark.apache.org/ -BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports +BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods diff --git a/README.md b/README.md index dd7d0e22495b..853f7f5ded3c 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,9 @@ To build Spark and its example programs, run: You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) -and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). + +For general development tips, including info on developing Spark using an IDE, see +[http://spark.apache.org/developer-tools.html](the Useful Developer Tools page). ## Interactive Scala Shell @@ -80,7 +81,7 @@ can be run using: ./dev/run-tests Please see the guidance on how to -[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). +[run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). ## A Note About Hadoop Versions @@ -100,5 +101,5 @@ in the online documentation for an overview on how to configure Spark. ## Contributing -Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -wiki for information on how to get started contributing to the project. +Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) +for information on how to get started contributing to the project. diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 92c5251c8503..fd73ca73ee7e 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -28,7 +28,7 @@ with Spark-specific changes from: - https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide + http://spark.apache.org/contributing.html#code-style-guide Checkstyle is very configurable. Be sure to read the documentation at http://checkstyle.sf.net (or in your downloaded distribution). diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index ad5b5c9adfac..c00d0db63cd1 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -113,8 +113,8 @@
                         <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
                         <li><a href="building-spark.html">Building Spark</a></li>
-                        <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li>
-                        <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects">Third Party Projects</a></li>
+                        <li><a href="http://spark.apache.org/contributing.html">Contributing to Spark</a></li>
+                        <li><a href="http://spark.apache.org/third-party-projects.html">Third Party Projects</a></li>
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index 88da0cc9c3bb..65c2895b29b1 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -197,7 +197,7 @@ can be set to control the SBT build. For example: To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command prompt. For more recommendations on reducing build time, refer to the -[wiki page](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-ReducingBuildTimes). +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). ## Encrypted Filesystems @@ -215,7 +215,7 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ ## IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). +[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). # Running Tests diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index ef1b3ad6da57..9252545e4a12 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -5,4 +5,4 @@ title: Contributing to Spark The Spark team welcomes all forms of contributions, including bug reports, documentation or patches. For the newest information on how to contribute to the project, please read the -[wiki page on contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). diff --git a/docs/index.md b/docs/index.md index 39de11de854a..c5d34cb5c4e7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -125,8 +125,8 @@ options for deployment: * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system -* [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -* [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects): related third party Spark projects +* [Contributing to Spark](http://spark.apache.org/contributing.html) +* [Third Party Projects](http://spark.apache.org/third-party-projects.html): related third party Spark projects **External Resources:** diff --git a/docs/sparkr.md b/docs/sparkr.md index f30bd4026fed..d26949226b11 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -126,7 +126,7 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. 
-SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects), you can find data source connectors for popular file formats like Avro. These packages can either be added by +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](http://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 18fc1cd93482..1fcd198685a5 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2382,7 +2382,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) +* Third-party DStream data sources can be found in [Third Party Projects](http://spark.apache.org/third-party-projects.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index cfee7be1e3f0..84fde0bbf926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -505,12 +505,11 @@ object DataSource { provider1 == "com.databricks.spark.avro") { throw new AnalysisException( s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at " + - "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects") + "package at http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + - "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects", + "http://spark.apache.org/third-party-projects.html", error) } } From 85235ed6c600270e3fa434738bd50dce3564440a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 23 Nov 2016 20:14:08 +0800 Subject: [PATCH 295/381] [SPARK-18545][SQL] Verify number of hive client RPCs in PartitionedTablePerfStatsSuite ## What changes were proposed in this pull request? This would help catch accidental O(n) calls to the hive client as in https://issues.apache.org/jira/browse/SPARK-18507 ## How was this patch tested? Checked that the test fails before https://issues.apache.org/jira/browse/SPARK-18507 was patched. cc cloud-fan Author: Eric Liang Closes #15985 from ericl/spark-18545. --- .../spark/metrics/source/StaticSources.scala | 7 +++ .../sql/hive/client/HiveClientImpl.scala | 1 + .../hive/PartitionedTablePerfStatsSuite.scala | 58 ++++++++++++++++++- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 3f7cfd9d2c11..b433cd0a89ac 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -85,6 +85,11 @@ object HiveCatalogMetrics extends Source { */ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + /** + * Tracks the total number of Hive client calls (e.g. to lookup a table). + */ + val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls")) + /** * Resets the values of all metrics to zero. This is useful in tests. 
*/ @@ -92,10 +97,12 @@ object HiveCatalogMetrics extends Source { METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount()) + METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount()) } // clients can use these to avoid classloader issues with the codahale classes def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n) + def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index daae8523c636..68dcfd86731b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -281,6 +281,7 @@ private[hive] class HiveClientImpl( shim.setCurrentSessionState(state) val ret = try f finally { Thread.currentThread().setContextClassLoader(original) + HiveCatalogMetrics.incrementHiveClientCalls(1) } ret } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index b41bc862e9bc..9838b9a4eba3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -57,7 +57,11 @@ class PartitionedTablePerfStatsSuite } private def setupPartitionedHiveTable(tableName: String, dir: File): Unit = { - spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + setupPartitionedHiveTable(tableName, dir, 5) + } + + private def setupPartitionedHiveTable(tableName: String, dir: File, scale: Int): Unit = { + spark.range(scale).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write .partitionBy("partCol1", "partCol2") .mode("overwrite") .parquet(dir.getAbsolutePath) @@ -71,7 +75,11 @@ class PartitionedTablePerfStatsSuite } private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { - spark.range(5).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write + setupPartitionedDatasourceTable(tableName, dir, 5) + } + + private def setupPartitionedDatasourceTable(tableName: String, dir: File, scale: Int): Unit = { + spark.range(scale).selectExpr("id as fieldOne", "id as partCol1", "id as partCol2").write .partitionBy("partCol1", "partCol2") .mode("overwrite") .parquet(dir.getAbsolutePath) @@ -242,6 +250,52 @@ class PartitionedTablePerfStatsSuite } } + test("hive table: num hive client calls does not scale with partition count") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedHiveTable("test", dir, scale = 100) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 1").count() == 1) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() > 0) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 100) + 
assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("show partitions test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + } + } + } + } + + test("datasource table: num hive client calls does not scale with partition count") { + withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "true") { + withTable("test") { + withTempDir { dir => + setupPartitionedDatasourceTable("test", dir, scale = 100) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test where partCol1 = 1").count() == 1) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() > 0) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("select * from test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + + HiveCatalogMetrics.reset() + assert(spark.sql("show partitions test").count() == 100) + assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() < 10) + } + } + } + } + test("hive table: files read and cached when filesource partition management is off") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { From 84284e8c82542d80dad94e458a0c0210bf803db3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 23 Nov 2016 04:15:19 -0800 Subject: [PATCH 296/381] [SPARK-18053][SQL] compare unsafe and safe complex-type values correctly ## What changes were proposed in this pull request? In Spark SQL, some expression may output safe format values, e.g. `CreateArray`, `CreateStruct`, `Cast`, etc. When we compare 2 values, we should be able to compare safe and unsafe formats. The `GreaterThan`, `LessThan`, etc. in Spark SQL already handles it, but the `EqualTo` doesn't. This PR fixes it. ## How was this patch tested? new unit test and regression test Author: Wenchen Fan Closes #15929 from cloud-fan/type-aware. 
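For illustration only (not part of this patch): a small repro sketch mirroring the regression test added below in SQLQuerySuite; the table name is made up. The stored column values are compared in the unsafe binary format, while the `ARRAY(1L)` literal in the filter is built as a safe `GenericArrayData`, so the fixed `EqualTo` must compare across the two representations.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array

object ArrayEqualitySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("array-equality-sketch").getOrCreate()
    import spark.implicits._

    // Persist an array column; the values read back are in the unsafe format.
    spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")

    // The ARRAY(1L) literal evaluates to a safe-format array; with this fix the
    // equality predicate matches exactly one row (arr = [1]).
    val matched = spark.sql("SELECT * FROM array_tbl WHERE arr = ARRAY(1L)").count()
    println(s"matching rows: $matched")
  }
}
```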
--- .../sql/catalyst/expressions/UnsafeRow.java | 6 +--- .../expressions/codegen/CodeGenerator.scala | 20 ++++++++++-- .../sql/catalyst/expressions/predicates.scala | 32 +++---------------- .../catalyst/expressions/PredicateSuite.scala | 29 +++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 ++++ 5 files changed, 59 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c3f0abac244c..d205547698c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -578,12 +578,8 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); - } else if (!(other instanceof InternalRow)) { - return false; - } else { - throw new IllegalArgumentException( - "Cannot compare UnsafeRow to " + other.getClass().getName()); } + return false; } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9c3c6d3b2a7f..09007b7c89fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -481,8 +481,13 @@ class CodegenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" + case array: ArrayType => genComp(array, c1, c2) + " == 0" + case struct: StructType => genComp(struct, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) - case other => s"$c1.equals($c2)" + case _ => + throw new IllegalArgumentException( + "cannot generate equality code for un-comparable type: " + dataType.simpleString) } /** @@ -512,6 +517,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { + // when comparing unsafe arrays, try equals first as it compares the binary directly + // which is very fast. + if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) { + return 0; + } int lengthA = a.numElements(); int lengthB = b.numElements(); int $minLength = (lengthA > lengthB) ? lengthB : lengthA; @@ -551,6 +561,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(InternalRow a, InternalRow b) { + // when comparing unsafe rows, try equals first as it compares the binary directly + // which is very fast. 
+ if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) { + return 0; + } InternalRow i = null; $comparisons return 0; @@ -561,7 +576,8 @@ class CodegenContext { case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => - throw new IllegalArgumentException("cannot generate compare code for un-comparable type") + throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type: " + dataType.simpleString) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2ad452b6a90c..3fcbb05372d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } + + protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) } @@ -429,17 +431,7 @@ case class EqualTo(left: Expression, right: Expression) override def symbol: String = "=" - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } - } + protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) @@ -482,15 +474,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } + ordering.equiv(input1, input2) } } @@ -513,8 +497,6 @@ case class LessThan(left: Expression, right: Expression) override def symbol: String = "<" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -527,8 +509,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -541,8 +521,6 @@ case class GreaterThan(left: Expression, right: Expression) override def symbol: String = ">" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, 
input2: Any): Any = ordering.gt(input1, input2) } @@ -555,7 +533,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 2a445b8cdb09..f9f6799e6e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(nullInt, normalInt), false) checkEvaluation(EqualNullSafe(nullInt, nullInt), true) } + + test("EqualTo on complex type") { + val array = new GenericArrayData(Array(1, 2, 3)) + val struct = create_row("a", 1L, array) + + val arrayType = ArrayType(IntegerType) + val structType = new StructType() + .add("1", StringType) + .add("2", LongType) + .add("3", ArrayType(IntegerType)) + + val projection = UnsafeProjection.create( + new StructType().add("array", arrayType).add("struct", structType)) + + val unsafeRow = projection(InternalRow(array, struct)) + + val unsafeArray = unsafeRow.getArray(0) + val unsafeStruct = unsafeRow.getStruct(1, 3) + + checkEvaluation(EqualTo( + Literal.create(array, arrayType), + Literal.create(unsafeArray, arrayType)), true) + + checkEvaluation(EqualTo( + Literal.create(struct, structType), + Literal.create(unsafeStruct, structType)), true) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a715176d55d9..d2ec3cfc0522 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2469,4 +2469,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } + } } From 9785ed40d7fe4e1fcd440e55706519c6e5f8d6b1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Nov 2016 04:22:26 -0800 Subject: [PATCH 297/381] [SPARK-18557] Downgrade confusing memory leak warning message ## What changes were proposed in this pull request? TaskMemoryManager has a memory leak detector that gets called at task completion callback and checks whether any memory has not been released. If they are not released by the time the callback is invoked, TaskMemoryManager releases them. 
The current error message says something like the following:

```
WARN [Executor task launch worker-0] org.apache.spark.memory.TaskMemoryManager -
leak 16.3 MB memory from org.apache.spark.unsafe.map.BytesToBytesMap33fb6a15
```

In practice, there are multiple reasons why these can be triggered in the normal code path (e.g. limit, or task failures), and the fact that these messages are logged at all means the "leak" has already been cleaned up by TaskMemoryManager.

To avoid confusing users, this patch downgrades the message from warning to debug level and avoids the word "leak", since it is not actually a leak.

## How was this patch tested?
N/A - this is a simple logging improvement.

Author: Reynold Xin

Closes #15989 from rxin/SPARK-18557.
---
 .../main/java/org/apache/spark/memory/TaskMemoryManager.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 1a700aa37554..c40974b54cb4 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -378,14 +378,14 @@ public long cleanUpAllAllocatedMemory() {
     for (MemoryConsumer c: consumers) {
       if (c != null && c.getUsed() > 0) {
         // In case of failed task, it's normal to see leaked memory
-        logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+        logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
       }
     }
     consumers.clear();

     for (MemoryBlock page : pageTable) {
       if (page != null) {
-        logger.warn("leak a page: " + page + " in task " + taskAttemptId);
+        logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
         memoryManager.tungstenMemoryAllocator().free(page);
       }
     }

From 70ad07a9d20586ae182c4e60ed97bdddbcbceff3 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 23 Nov 2016 20:48:41 +0800
Subject: [PATCH 298/381] [SPARK-18522][SQL] Explicit contract for column stats serialization

## What changes were proposed in this pull request?

The current implementation of column stats uses the base64 encoding of the internal UnsafeRow format to persist statistics (in table properties in the Hive metastore). This is an internal format that is not stable across different versions of Spark and should NOT be used for persistence. In addition, it would be better if the statistics stored in the catalog were human readable.

This pull request introduces the following changes:

1. Created a single ColumnStat class for all data types. All data types track the same set of statistics.
2. Updated the implementation for stats collection to get rid of the dependency on internal data structures (e.g. InternalRow, or storing DateType as an int32). For example, previously dates were stored as a single integer, but are now stored as java.sql.Date. When we implement the next steps of CBO, we can add code to convert those back into internal types again.
3. Documented clearly what JVM data types are being used to store what data.
4. Defined a simple Map[String, String] interface for serializing and deserializing column stats into/from the catalog (see the round-trip sketch below).
5. Rearranged the method/function structure so it is clearer what the supported data types are, and also moved how stats are generated into the ColumnStat class so they are easy to find.

## How was this patch tested?

Removed most of the original test cases created for column statistics, and added three very simple ones to cover all the cases.
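As a rough illustration of the Map[String, String] round trip that the new tests exercise (a sketch only; the column, its values, and the table name are made up, while `ColumnStat`, `toMap` and `fromMap` are the API added in this patch):

```scala
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.{IntegerType, StructField}

// Hypothetical stats for an INT column with distinct values 1..4 and one null.
// Integral min/max are stored as longs, per the ColumnStat contract.
val original = ColumnStat(
  distinctCount = 4, min = Some(1L), max = Some(4L),
  nullCount = 1, avgLen = 4, maxLen = 4)

// Serialized as plain strings for the catalog, e.g.
// Map("version" -> "1", "distinctCount" -> "4", "min" -> "1", "max" -> "4",
//     "nullCount" -> "1", "avgLen" -> "4", "maxLen" -> "4")
val serialized: Map[String, String] = original.toMap

// Deserializing with the column's schema recovers an equal ColumnStat.
val restored = ColumnStat.fromMap("some_table", StructField("c1", IntegerType), serialized)
assert(restored == Some(original))
```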
The three test cases validate: 1. Roundtrip serialization works. 2. Behavior when analyzing non-existent column or unsupported data type column. 3. Result for stats collection for all valid data types. Also moved parser related tests into a parser test suite and added an explicit serialization test for the Hive external catalog. Author: Reynold Xin Closes #15959 from rxin/SPARK-18522. --- .../catalyst/plans/logical/Statistics.scala | 212 ++++++++--- .../command/AnalyzeColumnCommand.scala | 105 +----- .../spark/sql/StatisticsCollectionSuite.scala | 218 ++++++++++++ .../spark/sql/StatisticsColumnSuite.scala | 334 ------------------ .../apache/spark/sql/StatisticsSuite.scala | 92 ----- .../org/apache/spark/sql/StatisticsTest.scala | 130 ------- .../sql/execution/SparkSqlParserSuite.scala | 26 +- .../spark/sql/hive/HiveExternalCatalog.scala | 93 +++-- .../spark/sql/hive/StatisticsSuite.scala | 299 ++++++---------- 9 files changed, 591 insertions(+), 918 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f3e2147b8f97..79865609cb64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.commons.codec.binary.Base64 +import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types._ + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -58,60 +61,175 @@ case class Statistics( } } + /** - * Statistics for a column. + * Statistics collected for a column. + * + * 1. Supported data types are defined in `ColumnStat.supportsType`. + * 2. The JVM data type stored in min/max is the external data type (used in Row) for the + * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for + * TimestampType we store java.sql.Timestamp. + * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. + * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * (sketches) might have been used, and the data collected can also be stale. + * + * @param distinctCount number of distinct values + * @param min minimum value + * @param max maximum value + * @param nullCount number of nulls + * @param avgLen average length of the values. For fixed-length types, this should be a constant. + * @param maxLen maximum length of the values. For fixed-length types, this should be a constant. 
*/ -case class ColumnStat(statRow: InternalRow) { +case class ColumnStat( + distinctCount: BigInt, + min: Option[Any], + max: Option[Any], + nullCount: BigInt, + avgLen: Long, + maxLen: Long) { - def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { - NumericColumnStat(statRow, dataType) - } - def forString: StringColumnStat = StringColumnStat(statRow) - def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) - def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) + // We currently don't store min/max for binary/string type. This can change in the future and + // then we need to remove this require. + require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) + require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - override def toString: String = { - // use Base64 for encoding - Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string + * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * In the case min/max values are null (None), they won't appear in the map. + */ + def toMap: Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(ColumnStat.KEY_VERSION, "1") + map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) + map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) + map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) + map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + map.toMap } } -object ColumnStat { - def apply(numFields: Int, str: String): ColumnStat = { - // use Base64 for decoding - val bytes = Base64.decodeBase64(str) - val unsafeRow = new UnsafeRow(numFields) - unsafeRow.pointTo(bytes, bytes.length) - ColumnStat(unsafeRow) + +object ColumnStat extends Logging { + + // List of string keys used to serialize ColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + + /** Returns true iff the we support gathering column statistics on column of the given type. */ + def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false } -} -case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { - // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. - val numNulls: Long = statRow.getLong(0) - val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] - val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] - val ndv: Long = statRow.getLong(3) -} + /** + * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats + * from some external storage. 
The serialization side is defined in [[ColumnStat.toMap]]. + */ + def fromMap(table: String, field: StructField, map: Map[String, String]) + : Option[ColumnStat] = { + val str2val: (String => Any) = field.dataType match { + case _: IntegralType => _.toLong + case _: DecimalType => new java.math.BigDecimal(_) + case DoubleType | FloatType => _.toDouble + case BooleanType => _.toBoolean + case DateType => java.sql.Date.valueOf + case TimestampType => java.sql.Timestamp.valueOf + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => _ => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column ${field.name} of data type: ${field.dataType}.") + } -case class StringColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) - val ndv: Long = statRow.getLong(3) -} + try { + Some(ColumnStat( + distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), + // Note that flatMap(Option.apply) turns Option(null) into None. + min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + nullCount = BigInt(map(KEY_NULL_COUNT).toLong), + avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, + maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) + None + } + } -case class BinaryColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. - val numNulls: Long = statRow.getLong(0) - val avgColLen: Double = statRow.getDouble(1) - val maxColLen: Long = statRow.getInt(2) -} + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + + def fixedLenTypeStruct(castType: DataType) = { + // For fixed width types, avg size should be the same as max size. 
+ val avgSize = Literal(col.dataType.defaultSize, LongType) + struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct(LongType) + case _: DecimalType => fixedLenTypeStruct(col.dataType) + case DoubleType | FloatType => fixedLenTypeStruct(DoubleType) + case BooleanType => fixedLenTypeStruct(col.dataType) + case DateType => fixedLenTypeStruct(col.dataType) + case TimestampType => fixedLenTypeStruct(col.dataType) + case BinaryType | StringType => + // For string and binary type, we don't store min/max. + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType)) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ + def rowToColumnStat(row: Row): ColumnStat = { + ColumnStat( + distinctCount = BigInt(row.getLong(0)), + min = Option(row.get(1)), // for string/binary min/max, get should return null + max = Option(row.get(2)), + nullCount = BigInt(row.getLong(3)), + avgLen = row.getLong(4), + maxLen = row.getLong(5) + ) + } -case class BooleanColumnStat(statRow: InternalRow) { - // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. - val numNulls: Long = statRow.getLong(0) - val numTrues: Long = statRow.getLong(1) - val numFalses: Long = statRow.getLong(2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7fc57d09e924..9dffe3614a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -24,9 +24,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.types._ /** @@ -62,7 +61,7 @@ case class AnalyzeColumnCommand( // Compute stats for each column val (rowCount, newColStats) = - AnalyzeColumnCommand.computeColStats(sparkSession, relation, columnNames) + AnalyzeColumnCommand.computeColumnStats(sparkSession, tableIdent.table, relation, columnNames) // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = Statistics( @@ -88,8 +87,9 @@ object AnalyzeColumnCommand extends Logging { * * This is visible for testing. 
*/ - def computeColStats( + def computeColumnStats( sparkSession: SparkSession, + tableName: String, relation: LogicalPlan, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { @@ -97,102 +97,33 @@ object AnalyzeColumnCommand extends Logging { val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = AttributeSet(columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) - exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) }).toSeq + // Make sure the column types are supported for stats gathering. + attributesToAnalyze.foreach { attr => + if (!ColumnStat.supportsType(attr.dataType)) { + throw new AnalysisException( + s"Column ${attr.name} in table $tableName is of type ${attr.dataType}, " + + "and Spark does not support statistics collection on this column type.") + } + } + // Collect statistics per column. // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(AnalyzeColumnCommand.createColumnStatStruct(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) - .queryExecution.toRdd.collect().head + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() - // unwrap the result - // TODO: Get rid of numFields by using the public Dataset API. val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - val numFields = AnalyzeColumnCommand.numStatFields(expr.dataType) - (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) + (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) }.toMap (rowCount, columnStats) } - - private val zero = Literal(0, LongType) - private val one = Literal(1, LongType) - - private def numNulls(e: Expression): Expression = { - if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero - } - private def max(e: Expression): Expression = Max(e) - private def min(e: Expression): Expression = Min(e) - private def ndv(e: Expression, relativeSD: Double): Expression = { - // the approximate ndv should never be larger than the number of rows - Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) - } - private def avgLength(e: Expression): Expression = Average(Length(e)) - private def maxLength(e: Expression): Expression = Max(Length(e)) - private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) - private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - - /** - * Creates a struct that groups the sequence of expressions together. This is used to create - * one top level struct per column. 
- */ - private def createStruct(exprs: Seq[Expression]): CreateNamedStruct = { - CreateStruct(exprs.map { expr: Expression => - expr.transformUp { - case af: AggregateFunction => af.toAggregateExpression() - } - }) - } - - private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) - } - - private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) - } - - private def binaryColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), avgLength(e), maxLength(e)) - } - - private def booleanColumnStat(e: Expression): Seq[Expression] = { - Seq(numNulls(e), numTrues(e), numFalses(e)) - } - - // TODO(rxin): Get rid of this function. - def numStatFields(dataType: DataType): Int = { - dataType match { - case BinaryType | BooleanType => 3 - case _ => 4 - } - } - - /** - * Creates a struct expression that contains the statistics to collect for a column. - * - * @param attr column to collect statistics - * @param relativeSD relative error for approximate number of distinct values. - */ - def createColumnStatStruct(attr: Attribute, relativeSD: Double): CreateNamedStruct = { - attr.dataType match { - case _: NumericType | TimestampType | DateType => - createStruct(numericColumnStat(attr, relativeSD)) - case StringType => - createStruct(stringColumnStat(attr, relativeSD)) - case BinaryType => - createStruct(binaryColumnStat(attr)) - case BooleanType => - createStruct(booleanColumnStat(attr)) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${attr.name} of data type: ${attr.dataType}.") - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala new file mode 100644 index 000000000000..1fcccd061079 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.SQLTestData.ArrayData +import org.apache.spark.sql.types._ + + +/** + * End-to-end suite testing statistics collection and use on both entire table and columns. 
+ */ +class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { + import testImplicits._ + + private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) + : Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } + + test("estimates the size of a limit 0 on outer join") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + val df1 = spark.table("test") + val df2 = spark.table("test").limit(0) + val df = df1.join(df2, Seq("k"), "left") + + val sizes = df.queryExecution.analyzed.collect { case g: Join => + g.statistics.sizeInBytes + } + + assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") + assert(sizes.head === BigInt(96), + s"expected exact size 96 for table 'test', got: ${sizes.head}") + } + } + + test("analyze column command - unsupported types and invalid columns") { + val tableName = "column_stats_test1" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + + // Test unsupported data types + val err1 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err1.message.contains("does not support statistics collection")) + + // Test invalid columns + val err2 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") + } + assert(err2.message.contains("does not exist")) + } + } + + test("test table-level statistics for data source table") { + val tableName = "tbl" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) + + // noscan won't count the number of rows + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + checkTableStats(tableName, expectedRowCount = None) + + // without noscan, we count the number of rows + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + checkTableStats(tableName, expectedRowCount = Some(2)) + } + } + + test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { + val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) + val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) + assert(df.queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + } + + test("estimates the size of limit") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => + val df = sql(s"""SELECT * FROM test limit $limit""") + + val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => + g.statistics.sizeInBytes + } + assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesGlobalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") + + val sizesLocalLimit = df.queryExecution.analyzed.collect { 
case l: LocalLimit => + l.statistics.sizeInBytes + } + assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesLocalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") + } + } + } + +} + + +/** + * The base for test cases that we want to include in both the hive module (for verifying behavior + * when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. */ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + ) + + test("column stats round trip serialization") { + // Make sure we serialize and then deserialize and we will get the result data + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + stats.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + assert(roundtrip == Some(v)) + } + } + } + + test("analyze column command - result verification") { + val tableName = "column_stats_test2" + // (data.head.productArity - 1) because the last column does not support stats collection. 
+ assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + + // Validate statistics + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == stats.size) + + stats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala deleted file mode 100644 index e866ac2cb3b3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.test.SQLTestData.ArrayData -import org.apache.spark.sql.types._ - -class StatisticsColumnSuite extends StatisticsTest { - import testImplicits._ - - test("parse analyze column commands") { - val tableName = "tbl" - - // we need to specify column names - intercept[ParseException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") - } - - val analyzeSql = s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key, value" - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) - val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) - comparePlans(parsed, expected) - } - - test("analyzing columns of non-atomic types is not supported") { - val tableName = "tbl" - withTable(tableName) { - Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) - val err = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") - } - assert(err.message.contains("Analyzing columns is not supported")) - } - } - - test("check correctness of columns") { - val table = "tbl" - val colName1 = "abc" - val colName2 = "x.yz" - withTable(table) { - sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") - - val invalidColError = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") - } - 
assert(invalidColError.message == "Invalid column name: key.") - - withSQLConf("spark.sql.caseSensitive" -> "true") { - val invalidErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") - } - assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") - } - - withSQLConf("spark.sql.caseSensitive" -> "false") { - val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columnsToAnalyze) - assert(columnStats.contains(colName1)) - assert(columnStats.contains(colName2)) - // check deduplication - assert(columnStats.size == 2) - assert(!columnStats.contains(colName2.toUpperCase)) - } - } - } - - private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { - values.filter(_.isDefined).map(_.get) - } - - test("column-level statistics for integral type columns") { - val values = (0 to 5).map { i => - if (i % 2 == 0) None else Some(i) - } - val data = values.map { i => - (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) - } - - val df = data.toDF("c1", "c2", "c3", "c4") - val nonNullValues = getNonNullValues[Int](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.max, - nonNullValues.min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for fractional type columns") { - val values: Seq[Option[Decimal]] = (0 to 5).map { i => - if (i == 0) None else Some(Decimal(i + i * 0.01)) - } - val data = values.map { i => - (i.map(_.toFloat), i.map(_.toDouble), i) - } - - val df = data.toDF("c1", "c2", "c3") - val nonNullValues = getNonNullValues[Decimal](values) - val numNulls = values.count(_.isEmpty).toLong - val ndv = nonNullValues.distinct.length.toLong - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case floatType: FloatType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, - ndv)) - case doubleType: DoubleType => - ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, - ndv)) - case decimalType: DecimalType => - ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for string column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[String](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for binary column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Array[Byte]](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - 
nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toInt)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for boolean column") { - val values = Seq(None, Some(true), Some(false), Some(true)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Boolean](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.count(_.equals(true)).toLong, - nonNullValues.count(_.equals(false)).toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for date column") { - val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Date](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, DateType is represented as the number of days from 1970-01-01. - nonNullValues.map(DateTimeUtils.fromJavaDate).max, - nonNullValues.map(DateTimeUtils.fromJavaDate).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for timestamp column") { - val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => - i.map(Timestamp.valueOf) - } - val df = values.toDF("c1") - val nonNullValues = getNonNullValues[Timestamp](values) - val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat(InternalRow( - values.count(_.isEmpty).toLong, - // Internally, TimestampType is represented as the number of days from 1970-01-01 - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, - nonNullValues.distinct.length.toLong)) - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for null columns") { - val values = Seq(None, None) - val data = values.map { i => - (i.map(_.toString), i.map(_.toString.toInt)) - } - val df = data.toDF("c1", "c2") - val expectedColStatsSeq = df.schema.map { f => - (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) - } - checkColStats(df, expectedColStatsSeq) - } - - test("column-level statistics for columns with different types") { - val intSeq = Seq(1, 2) - val doubleSeq = Seq(1.01d, 2.02d) - val stringSeq = Seq("a", "bb") - val binarySeq = Seq("a", "bb").map(_.getBytes) - val booleanSeq = Seq(true, false) - val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) - val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) - val longSeq = Seq(5L, 4L) - - val data = intSeq.indices.map { i => - (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), - timestampSeq(i), longSeq(i)) - } - val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") - val expectedColStatsSeq = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case DoubleType => - ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, - doubleSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BinaryType => - 
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toInt)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - case DateType => - ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, - dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) - case TimestampType => - ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, - timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, - timestampSeq.distinct.length.toLong)) - case LongType => - ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) - } - (f, colStat) - } - checkColStats(df, expectedColStatsSeq) - } - - test("update table-level stats while collecting column-level stats") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int) USING PARQUET") - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - checkTableStats(tableName = table, expectedRowCount = Some(1)) - - // update table-level stats between analyze table and analyze column commands - sql(s"INSERT INTO $table SELECT 1") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) - - val colStat = fetchedStats.get.colStats("c1") - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = colStat, - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - } - } - - test("analyze column stats independently") { - val table = "tbl" - withTable(table) { - sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - assert(fetchedStats1.get.colStats.size == 1) - val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) - val rsd = spark.sessionState.conf.ndvMaxError - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats1.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") - val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) - // column c1 is kept in the stats - assert(fetchedStats2.get.colStats.size == 2) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = fetchedStats2.get.colStats("c1"), - expectedColStat = expected1, - rsd = rsd) - val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) - StatisticsTest.checkColStat( - dataType = LongType, - colStat = fetchedStats2.get.colStats("c2"), - expectedColStat = expected2, - rsd = rsd) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala deleted file mode 100644 index 8cf42e9248c2..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.types._ - -class StatisticsSuite extends StatisticsTest { - import testImplicits._ - - test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { - val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) - val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > - spark.sessionState.conf.autoBroadcastJoinThreshold) - } - - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.statistics.sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.statistics.sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - - test("estimates the size of a limit 0 on outer join") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - val df1 = spark.table("test") - val df2 = spark.table("test").limit(0) - val df = df1.join(df2, Seq("k"), "left") - - val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.statistics.sizeInBytes - } - - assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), - s"expected exact size 96 for table 'test', got: ${sizes.head}") - } - } - - test("test table-level statistics for data source table created in InMemoryCatalog") { - val tableName = "tbl" - withTable(tableName) { - sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) - - // noscan won't count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) - - // without noscan, we count the number of rows - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala deleted file mode 100644 index 915ee0d31bca..000000000000 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - - -trait StatisticsTest extends QueryTest with SharedSQLContext { - - def checkColStats( - df: DataFrame, - expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { - val table = "tbl" - withTable(table) { - df.write.format("json").saveAsTable(table) - val columns = expectedColStatsSeq.map(_._1) - val tableIdent = TableIdentifier(table, Some("default")) - val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val (_, columnStats) = - AnalyzeColumnCommand.computeColStats(spark, relation, columns.map(_.name)) - expectedColStatsSeq.foreach { case (field, expectedColStat) => - assert(columnStats.contains(field.name)) - val colStat = columnStats(field.name) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = colStat, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - - // check if we get the same colStat after encoding and decoding - val encodedCS = colStat.toString - val numFields = AnalyzeColumnCommand.numStatFields(field.dataType) - val decodedCS = ColumnStat(numFields, encodedCS) - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = decodedCS, - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - } - } - } - - def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } -} - -object StatisticsTest { - def checkColStat( - dataType: DataType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - dataType match { - case StringType => - val cs = colStat.forString - val expectedCS = expectedColStat.forString - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - case BinaryType => - val cs = colStat.forBinary - val expectedCS = expectedColStat.forBinary - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.avgColLen == 
expectedCS.avgColLen) - assert(cs.maxColLen == expectedCS.maxColLen) - case BooleanType => - val cs = colStat.forBoolean - val expectedCS = expectedColStat.forBoolean - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.numTrues == expectedCS.numTrues) - assert(cs.numFalses == expectedCS.numFalses) - case atomicType: AtomicType => - checkNumericColStats( - dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) - } - } - - private def checkNumericColStats( - dataType: AtomicType, - colStat: ColumnStat, - expectedColStat: ColumnStat, - rsd: Double): Unit = { - val cs = colStat.forNumeric(dataType) - val expectedCS = expectedColStat.forNumeric(dataType) - assert(cs.numNulls == expectedCS.numNulls) - assert(cs.max == expectedCS.max) - assert(cs.min == expectedCS.min) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) - } - - private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { - // ndv is an approximate value, so we make sure we have the value, and it should be - // within 3*SD's of the given rsd. - if (expectedNdv == 0) { - assert(ndv == 0) - } else if (expectedNdv > 0) { - assert(ndv > 0) - val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) - assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 797fe9ffa8be..b070138be05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DescribeFunctionCommand, - DescribeTableCommand, ShowFunctionsCommand} -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing} +import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -221,12 +220,22 @@ class SparkSqlParserSuite extends PlanTest { intercept("explain describe tables x", "Unsupported SQL statement") } - test("SPARK-18106 analyze table") { + test("analyze table statistics") { assertEqual("analyze table t compute statistics", AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) assertEqual("analyze table t compute statistics noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) - assertEqual("analyze table t partition (a) compute statistics noscan", + assertEqual("analyze table t partition (a) compute statistics nOscAn", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + + // Partitions specified - we currently parse them but don't do anything with it + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE 
STATISTICS", + AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) intercept("analyze table t compute statistics xxxx", @@ -234,4 +243,11 @@ class SparkSqlParserSuite extends PlanTest { intercept("analyze table t partition (a) compute statistics xxxx", "Expected `NOSCAN` instead of `xxxx`") } + + test("analyze table column statistics") { + intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "") + + assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", + AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ff0923f04893..fd9dc3206387 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, DDLUtils} +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -514,7 +514,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } stats.colStats.foreach { case (colName, colStat) => - statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString + colStat.toMap.foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -605,48 +607,65 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * It reads table schema, provider, partition column names and bucket specification from table * properties, and filter out these special entries from table properties. */ - private def restoreTableMetadata(table: CatalogTable): CatalogTable = { + private def restoreTableMetadata(inputTable: CatalogTable): CatalogTable = { if (conf.get(DEBUG_MODE)) { - return table + return inputTable } - val tableWithSchema = if (table.tableType == VIEW) { - table - } else { - getProviderFromTableProperties(table) match { + var table = inputTable + + if (table.tableType != VIEW) { + table.properties.get(DATASOURCE_PROVIDER) match { // No provider in table properties, which means this table is created by Spark prior to 2.1, // or is created at Hive side. case None => - table.copy(provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) + table = table.copy( + provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) // This is a Hive serde table created by Spark 2.1 or higher versions. - case Some(DDLUtils.HIVE_PROVIDER) => restoreHiveSerdeTable(table) + case Some(DDLUtils.HIVE_PROVIDER) => + table = restoreHiveSerdeTable(table) // This is a regular data source table. 
- case Some(provider) => restoreDataSourceTable(table, provider) + case Some(provider) => + table = restoreDataSourceTable(table, provider) } } // construct Spark's statistics from information in Hive metastore - val statsProps = tableWithSchema.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - val tableWithStats = if (statsProps.nonEmpty) { - val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) - .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } - val colStats: Map[String, ColumnStat] = tableWithSchema.schema.collect { - case f if colStatsProps.contains(f.name) => - val numFields = AnalyzeColumnCommand.numStatFields(f.dataType) - (f.name, ColumnStat(numFields, colStatsProps(f.name))) - }.toMap - tableWithSchema.copy( + val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + + if (statsProps.nonEmpty) { + val colStats = new scala.collection.mutable.HashMap[String, ColumnStat] + + // For each column, recover its column stats. Note that this is currently a O(n^2) operation, + // but given the number of columns it usually not enormous, this is probably OK as a start. + // If we want to map this a linear operation, we'd need a stronger contract between the + // naming convention used for serialization. + table.schema.foreach { field => + if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { + // If "version" field is defined, then the column stat is defined. + val keyPrefix = columnStatKeyPropName(field.name, "") + val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => + (k.drop(keyPrefix.length), v) + } + + ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach { + colStat => colStats += field.name -> colStat + } + } + } + + table = table.copy( stats = Some(Statistics( - sizeInBytes = BigInt(tableWithSchema.properties(STATISTICS_TOTAL_SIZE)), - rowCount = tableWithSchema.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - colStats = colStats))) - } else { - tableWithSchema + sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)), + rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats.toMap))) } - tableWithStats.copy(properties = getOriginalTableProperties(table)) + // Get the original table properties as defined by the user. + table.copy( + properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { @@ -1020,17 +1039,17 @@ object HiveExternalCatalog { val TABLE_PARTITION_PROVIDER_CATALOG = "catalog" val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem" - - def getProviderFromTableProperties(metadata: CatalogTable): Option[String] = { - metadata.properties.get(DATASOURCE_PROVIDER) - } - - def getOriginalTableProperties(metadata: CatalogTable): Map[String, String] = { - metadata.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) } + /** + * Returns the fully qualified name used in table properties for a particular column stat. + * For example, for column "mycol", and "min" stat, this should return + * "spark.sql.statistics.colStats.mycol.min". + */ + private def columnStatKeyPropName(columnName: String, statKey: String): String = { + STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey } // A persisted data source table always store its schema in the catalog. 
- def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { + private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." val props = metadata.properties val schema = props.get(DATASOURCE_SCHEMA) @@ -1078,11 +1097,11 @@ object HiveExternalCatalog { ) } - def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = { + private def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = { getColumnNamesByType(metadata.properties, "part", "partitioning columns") } - def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = { + private def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = { metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets => BucketSpec( numBuckets.toInt, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 4f5ebc3d838b..5ae202fdc98d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,56 +22,16 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { - - test("parse analyze commands") { - def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) - val operators = parsed.collect { - case a: AnalyzeTableCommand => a - case o => o - } - - assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail( - s"""$analyzeCommand expected command: $c, but got ${operators(0)} - |parsed command: - |$parsed - """.stripMargin) - } - } - - assertAnalyzeCommand( - "ANALYZE TABLE Table1 COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", - classOf[AnalyzeTableCommand]) - assertAnalyzeCommand( - "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", - classOf[AnalyzeTableCommand]) - - assertAnalyzeCommand( - "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", - classOf[AnalyzeTableCommand]) - } +class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("MetastoreRelations fallback to HDFS for size estimation") { val 
enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled @@ -310,6 +270,110 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + test("verify serialized column stats after analyzing columns") { + import testImplicits._ + + val tableName = "column_stats_test2" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + + // Validate statistics + val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val table = hiveClient.getTable("default", tableName) + + val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) + assert(props == Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + 
"spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + )) + } + } + private def testUpdatingTableStats(tableDescription: String, createTableCmd: String): Unit = { test("test table-level statistics for " + tableDescription) { val parquetTable = "parquetTable" @@ -319,7 +383,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier(parquetTable)) assert(DDLUtils.isDatasourceTable(catalogTable)) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + // Add a filter to avoid creating too many partitions + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) @@ -328,7 +393,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val fetchedStats1 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") + sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src WHERE key < 10") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") val fetchedStats2 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) @@ -340,7 +405,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils parquetTable, isDataSourceTable = true, hasSizeInBytes = true, 
- expectedRowCounts = Some(1000)) + expectedRowCounts = Some(20)) assert(fetchedStats3.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } } @@ -369,6 +434,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + /** Used to test refreshing cached metadata once table stats are updated. */ private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = { val tableName = "tbl" var statsBeforeUpdate: Statistics = null @@ -411,145 +477,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils assert(statsAfterUpdate.rowCount == Some(2)) } - test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") { - val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true) - - assert(statsBeforeUpdate.sizeInBytes > 0) - assert(statsBeforeUpdate.rowCount == Some(1)) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = statsBeforeUpdate.colStats("key"), - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - - assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes) - assert(statsAfterUpdate.rowCount == Some(2)) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = statsAfterUpdate.colStats("key"), - expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)), - rsd = spark.sessionState.conf.ndvMaxError) - } - - private lazy val (testDataFrame, expectedColStatsSeq) = { - import testImplicits._ - - val intSeq = Seq(1, 2) - val stringSeq = Seq("a", "bb") - val binarySeq = Seq("a", "bb").map(_.getBytes) - val booleanSeq = Seq(true, false) - val data = intSeq.indices.map { i => - (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i)) - } - val df: DataFrame = data.toDF("c1", "c2", "c3", "c4") - val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f => - val colStat = f.dataType match { - case IntegerType => - ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) - case StringType => - ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, - stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong)) - case BinaryType => - ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, - binarySeq.map(_.length).max.toInt)) - case BooleanType => - ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, - booleanSeq.count(_.equals(false)).toLong)) - } - (f, colStat) - } - (df, expectedColStatsSeq) - } - - private def checkColStats( - tableName: String, - isDataSourceTable: Boolean, - expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { - val readback = spark.table(tableName) - val stats = readback.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats.get - case rel: LogicalRelation => - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats.get - } - assert(stats.length == 1) - val columnStats = stats.head.colStats - assert(columnStats.size == expectedColStatsSeq.length) - expectedColStatsSeq.foreach { case (field, expectedColStat) => - StatisticsTest.checkColStat( - dataType = field.dataType, - colStat = columnStats(field.name), - expectedColStat = expectedColStat, - rsd = spark.sessionState.conf.ndvMaxError) - } - } - - test("generate and load 
column-level stats for data source table") { - val dsTable = "dsTable" - withTable(dsTable) { - testDataFrame.write.format("parquet").saveAsTable(dsTable) - sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") - checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq) - } - } - - test("generate and load column-level stats for hive serde table") { - val hTable = "hTable" - val tmp = "tmp" - withTable(hTable, tmp) { - testDataFrame.write.format("parquet").saveAsTable(tmp) - sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE") - sql(s"INSERT INTO $hTable SELECT * FROM $tmp") - sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4") - checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq) - } - } - - // When caseSensitive is on, for columns with only case difference, they are different columns - // and we should generate column stats for all of them. - private def checkCaseSensitiveColStats(columnName: String): Unit = { - val tableName = "tbl" - withTable(tableName) { - val column1 = columnName.toLowerCase - val column2 = columnName.toUpperCase - withSQLConf("spark.sql.caseSensitive" -> "true") { - sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET") - sql(s"INSERT INTO $tableName SELECT 1, 3.0") - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`") - val readback = spark.table(tableName) - val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => - val columnStats = rel.catalogTable.get.stats.get.colStats - assert(columnStats.size == 2) - StatisticsTest.checkColStat( - dataType = IntegerType, - colStat = columnStats(column1), - expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - StatisticsTest.checkColStat( - dataType = DoubleType, - colStat = columnStats(column2), - expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)), - rsd = spark.sessionState.conf.ndvMaxError) - rel - } - assert(relations.size == 1) - } - } - } - - test("check column statistics for case sensitive column names") { - checkCaseSensitiveColStats(columnName = "c1") - } - - test("check column statistics for case sensitive non-ascii column names") { - // scalastyle:off - // non ascii characters are not allowed in the source code, so we disable the scalastyle. - checkCaseSensitiveColStats(columnName = "列c") - // scalastyle:on - } - test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => From f129ebcd302168b628f47705f4a7d6b7e7b057b0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 23 Nov 2016 12:54:18 -0500 Subject: [PATCH 299/381] [SPARK-18050][SQL] do not create default database if it already exists ## What changes were proposed in this pull request? When we try to create the default database, we ask hive to do nothing if it already exists. However, Hive will log an error message instead of doing nothing, and the error message is quite annoying and confusing. In this PR, we only create default database if it doesn't exist. ## How was this patch tested? N/A Author: Wenchen Fan Closes #15993 from cloud-fan/default-db. 
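To spell out the intended behavior, the change reduces to the guard sketched below (it mirrors the hunk in the diff; `externalCatalog` and `defaultDbDefinition` are the values already defined in `SharedState`). `ignoreIfExists = true` is still passed because another Spark application may create the default database between the existence check and the create call.

```scala
// Create the default database only when it is missing, so Hive stops logging a
// confusing error for the already-exists case; keep ignoreIfExists = true to
// tolerate a concurrent creation by another application.
if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) {
  externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true)
}
```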
--- .../scala/org/apache/spark/sql/internal/SharedState.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 6232c18b1cea..8de95fe64e66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -92,8 +92,12 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { { val defaultDbDefinition = CatalogDatabase( SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map()) - // Initialize default database if it doesn't already exist - externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) + // Initialize default database if it doesn't exist + if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { + // There may be another Spark application creating default database at the same time, here we + // set `ignoreIfExists = true` to avoid `DatabaseAlreadyExists` exception. + externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) + } } /** From 0d1bf2b6c8ac4d4141d7cef0552c22e586843c57 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 23 Nov 2016 11:48:59 -0800 Subject: [PATCH 300/381] [SPARK-18510] Fix data corruption from inferred partition column dataTypes ## What changes were proposed in this pull request? ### The Issue If I specify my schema when doing ```scala spark.read .schema(someSchemaWherePartitionColumnsAreStrings) ``` but if the partition inference can infer it as IntegerType or I assume LongType or DoubleType (basically fixed size types), then once UnsafeRows are generated, your data will be corrupted. ### Proposed solution The partition handling code path is kind of a mess. In my fix I'm probably adding to the mess, but at least trying to standardize the code path. The real issue is that a user that uses the `spark.read` code path can never clearly specify what the partition columns are. If you try to specify the fields in `schema`, we practically ignore what the user provides, and fall back to our inferred data types. What happens in the end is data corruption. My solution tries to fix this by always trying to infer partition columns the first time you specify the table. Once we find what the partition columns are, we try to find them in the user specified schema and use the dataType provided there, or fall back to the smallest common data type. We will ALWAYS append partition columns to the user's schema, even if they didn't ask for it. We will only use the data type they provided if they specified it. While this is confusing, this has been the behavior since Spark 1.6, and I didn't want to change this behavior in the QA period of Spark 2.1. We may revisit this decision later. A side effect of this PR is that we won't need https://github.com/apache/spark/pull/15942 if this PR goes in. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #15951 from brkyvz/partition-corruption. 
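To make the failure mode concrete, here is a minimal reproduction sketch (not part of the patch; it assumes a running `SparkSession` named `spark`, and the path and column names are illustrative, mirroring the regression test added below):

```scala
import org.apache.spark.sql.types._

// Write a table partitioned by a column whose values look like integers.
spark.range(4)
  .selectExpr("id", "cast(id % 4 as int) as part")
  .write.partitionBy("part").mode("overwrite").parquet("/tmp/spark18510")

// Read it back while declaring the partition column as a string. Partition
// inference still sees `part` as IntegerType; before this fix the fixed-size
// inferred type silently replaced the user-specified type, so the generated
// UnsafeRows no longer matched the declared schema and the values read back
// were corrupted. With the fix, the user-provided StringType is honored and
// partition columns are appended at the end of the schema.
val userSchema = new StructType()
  .add("id", LongType)
  .add("part", StringType)

val df = spark.read.schema(userSchema).parquet("/tmp/spark18510")
df.printSchema()
```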
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../execution/datasources/DataSource.scala | 159 ++++++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 2 +- .../test/DataStreamReaderWriterSuite.scala | 45 ++++- .../sql/test/DataFrameReaderWriterSuite.scala | 38 ++++- 6 files changed, 190 insertions(+), 58 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ee48baa59c7a..c669c2e2e26e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2684,7 +2684,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 84fde0bbf926..dbc3e712332f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -61,8 +61,12 @@ import org.apache.spark.util.Utils * qualified. This option only works when reading from a [[FileFormat]]. * @param userSpecifiedSchema An optional specification of the schema of the data. When present * we skip attempting to infer the schema. - * @param partitionColumns A list of column names that the relation is partitioned by. When this - * list is empty, the relation is unpartitioned. + * @param partitionColumns A list of column names that the relation is partitioned by. This list is + * generally empty during the read path, unless this DataSource is managed + * by Hive. In these cases, during `resolveRelation`, we will call + * `getOrInferFileFormatSchema` for file based DataSources to infer the + * partitioning. In other cases, if this list is empty, then this table + * is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. * @param catalogTable Optional catalog table reference that can be used to push down operations * over the datasource to the catalog service. @@ -84,30 +88,106 @@ case class DataSource( private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** - * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. + * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer + * it. In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510. + * This method will try to skip file scanning whether `userSpecifiedSchema` and + * `partitionColumns` are provided. 
Here are some code paths that use this method: + * 1. `spark.read` (no schema): Most amount of work. Infer both schema and partitioning columns + * 2. `spark.read.schema(userSpecifiedSchema)`: Parse partitioning columns, cast them to the + * dataTypes provided in `userSpecifiedSchema` if they exist or fallback to inferred + * dataType if they don't. + * 3. `spark.readStream.schema(userSpecifiedSchema)`: For streaming use cases, users have to + * provide the schema. Here, we also perform partition inference like 2, and try to use + * dataTypes in `userSpecifiedSchema`. All subsequent triggers for this stream will re-use + * this information, therefore calls to this method should be very cheap, i.e. there won't + * be any further inference in any triggers. + * 4. `df.saveAsTable(tableThatExisted)`: In this case, we call this method to resolve the + * existing table's partitioning scheme. This is achieved by not providing + * `userSpecifiedSchema`. For this case, we add the boolean `justPartitioning` for an early + * exit, if we don't care about the schema of the original table. + * + * @param format the file format object for this DataSource + * @param justPartitioning Whether to exit early and provide just the schema partitioning. + * @return A pair of the data schema (excluding partition columns) and the schema of the partition + * columns. If `justPartitioning` is `true`, then the dataSchema will be provided as + * `null`. */ - private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { - userSpecifiedSchema.map(_ -> partitionColumns).orElse { - val allPaths = caseInsensitiveOptions.get("path") + private def getOrInferFileFormatSchema( + format: FileFormat, + justPartitioning: Boolean = false): (StructType, StructType) = { + // the operations below are expensive therefore try not to do them if we don't need to + lazy val tempFileCatalog = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.toSeq.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None) - val partitionSchema = fileCatalog.partitionSpec().partitionColumns - val inferred = format.inferSchema( + new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + } + val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { + // Try to infer partitioning, because no DataSource in the read path provides the partitioning + // columns properly unless it is a Hive DataSource + val resolved = tempFileCatalog.partitionSchema.map { partitionField => + val equality = sparkSession.sessionState.conf.resolver + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } else { + // in streaming mode, we have already inferred and registered partition columns, we will + // never have to materialize the lazy val below + lazy val inferredPartitions = tempFileCatalog.partitionSchema + // maintain old behavior before SPARK-18510. 
If userSpecifiedSchema is empty used inferred + // partitioning + if (userSpecifiedSchema.isEmpty) { + inferredPartitions + } else { + val partitionFields = partitionColumns.map { partitionColumn => + userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse { + val inferredOpt = inferredPartitions.find(_.name == partitionColumn) + if (inferredOpt.isDefined) { + logDebug( + s"""Type of partition column: $partitionColumn not found in specified schema + |for $format. + |User Specified Schema + |===================== + |${userSpecifiedSchema.orNull} + | + |Falling back to inferred dataType if it exists. + """.stripMargin) + } + inferredPartitions.find(_.name == partitionColumn) + }.getOrElse { + throw new AnalysisException(s"Failed to resolve the schema for $format for " + + s"the partition column: $partitionColumn. It must be specified manually.") + } + } + StructType(partitionFields) + } + } + if (justPartitioning) { + return (null, partitionSchema) + } + val dataSchema = userSpecifiedSchema.map { schema => + val equality = sparkSession.sessionState.conf.resolver + StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) + }.orElse { + format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.allFiles()) - - inferred.map { inferredSchema => - StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name) - } + tempFileCatalog.allFiles()) }.getOrElse { - throw new AnalysisException("Unable to infer schema. It must be specified manually.") + throw new AnalysisException( + s"Unable to infer schema for $format. It must be specified manually.") } + (dataSchema, partitionSchema) } /** Returns the name and schema of the source that can be used to continually read data. */ @@ -144,8 +224,8 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (schema, partCols) = inferFileFormatSchema(format) - SourceInfo(s"FileSource[$path]", schema, partCols) + val (schema, partCols) = getOrInferFileFormatSchema(format) + SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames) case _ => throw new UnsupportedOperationException( @@ -272,7 +352,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSpec().partitionColumns, + partitionSchema = fileCatalog.partitionSchema, dataSchema = dataSchema, bucketSpec = None, format, @@ -281,9 +361,10 @@ case class DataSource( // This is a non-streaming file based datasource. 
case (format: FileFormat, _) => val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) @@ -291,23 +372,14 @@ case class DataSource( throw new AnalysisException(s"Path does not exist: $qualified") } // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode if (checkFilesExist && !fs.exists(globPath.head)) { throw new AnalysisException(s"Path does not exist: ${globPath.head}") } globPath }.toArray - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. - schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - } + val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -316,27 +388,12 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, partitionSchema) - } - - val dataSchema = userSpecifiedSchema.map { schema => - val equality = sparkSession.sessionState.conf.resolver - StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - }.orElse { - format.inferSchema( - sparkSession, - caseInsensitiveOptions, - fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles()) - }.getOrElse { - throw new AnalysisException( - s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " + - "It must be specified manually") + new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema)) } HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSchema, + partitionSchema = inferredPartitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -384,11 +441,7 @@ case class DataSource( // up. If we fail to load the table for whatever reason, ignore the check. if (mode == SaveMode.Append) { val existingPartitionColumns = Try { - resolveRelation() - .asInstanceOf[HadoopFsRelation] - .partitionSchema - .fieldNames - .toSeq + getOrInferFileFormatSchema(format, justPartitioning = true)._2.fieldNames.toList }.getOrElse(Seq.empty[String]) // TODO: Case sensitivity. 
val sameColumns = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 02d9d1568490..10843e9ba575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -274,7 +274,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = new StructType().add("str", StringType).add("num", IntegerType), expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a099153d2e58..bad6642ea405 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -282,7 +282,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { createFileStreamSourceAndGetSchema( format = Some("json"), path = Some(src.getCanonicalPath), schema = None) } - assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) + assert("Unable to infer schema for JSON. It must be specified manually.;" === e.getMessage) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 5630464f4080..0eb95a02432f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object LastOptions { @@ -532,4 +532,47 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { assert(e.getMessage.contains("does not support recovering")) assert(e.getMessage.contains("checkpoint location")) } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. 
+ // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + + val sdf = spark.readStream + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(sdf.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + val sq = sdf.writeStream + .queryName("corruption_test") + .format("memory") + .start() + sq.processAllAvailable() + checkAnswer( + spark.table("corruption_test"), + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + sq.stop() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a7fda0109856..e0887e0f1c7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -573,4 +573,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. + // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + val df = spark.read + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(df.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + checkAnswer( + df, + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + } + } } From 223fa218e1f637f0d62332785a3bee225b65b990 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 23 Nov 2016 16:15:35 -0800 Subject: [PATCH 301/381] [SPARK-18510][SQL] Follow up to address comments in #15951 ## What changes were proposed in this pull request? This PR addressed the rest comments in #15951. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #15997 from zsxwing/SPARK-18510-follow-up. 
--- .../execution/datasources/DataSource.scala | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index dbc3e712332f..ccfc759c8fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -118,8 +118,10 @@ case class DataSource( private def getOrInferFileFormatSchema( format: FileFormat, justPartitioning: Boolean = false): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to - lazy val tempFileCatalog = { + // the operations below are expensive therefore try not to do them if we don't need to, e.g., + // in streaming mode, we have already inferred and registered partition columns, we will + // never have to materialize the lazy val below + lazy val tempFileIndex = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.toSeq.flatMap { path => @@ -133,7 +135,7 @@ case class DataSource( val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - val resolved = tempFileCatalog.partitionSchema.map { partitionField => + val resolved = tempFileIndex.partitionSchema.map { partitionField => val equality = sparkSession.sessionState.conf.resolver // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( @@ -141,17 +143,17 @@ case class DataSource( } StructType(resolved) } else { - // in streaming mode, we have already inferred and registered partition columns, we will - // never have to materialize the lazy val below - lazy val inferredPartitions = tempFileCatalog.partitionSchema // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning if (userSpecifiedSchema.isEmpty) { + val inferredPartitions = tempFileIndex.partitionSchema inferredPartitions } else { val partitionFields = partitionColumns.map { partitionColumn => - userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse { - val inferredOpt = inferredPartitions.find(_.name == partitionColumn) + val equality = sparkSession.sessionState.conf.resolver + userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse { + val inferredPartitions = tempFileIndex.partitionSchema + val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn)) if (inferredOpt.isDefined) { logDebug( s"""Type of partition column: $partitionColumn not found in specified schema @@ -163,7 +165,7 @@ case class DataSource( |Falling back to inferred dataType if it exists. """.stripMargin) } - inferredPartitions.find(_.name == partitionColumn) + inferredOpt }.getOrElse { throw new AnalysisException(s"Failed to resolve the schema for $format for " + s"the partition column: $partitionColumn. 
It must be specified manually.") @@ -182,7 +184,7 @@ case class DataSource( format.inferSchema( sparkSession, caseInsensitiveOptions, - tempFileCatalog.allFiles()) + tempFileIndex.allFiles()) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") @@ -224,8 +226,11 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (schema, partCols) = getOrInferFileFormatSchema(format) - SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + SourceInfo( + s"FileSource[$path]", + StructType(dataSchema ++ partitionSchema), + partitionSchema.fieldNames) case _ => throw new UnsupportedOperationException( @@ -379,7 +384,7 @@ case class DataSource( globPath }.toArray - val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -388,12 +393,12 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema)) + new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(partitionSchema)) } HadoopFsRelation( fileCatalog, - partitionSchema = inferredPartitionSchema, + partitionSchema = partitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, From 2dfabec38c24174e7f747c27c7144f7738483ec1 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 24 Nov 2016 05:46:05 -0800 Subject: [PATCH 302/381] [SPARK-18520][ML] Add missing setXXXCol methods for BisectingKMeansModel and GaussianMixtureModel ## What changes were proposed in this pull request? add `setFeaturesCol` and `setPredictionCol` for BiKModel and GMModel add `setProbabilityCol` for GMModel ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15957 from zhengruifeng/bikm_set. 
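For reference, a short usage sketch of the new setters (hedged illustration, not taken from the patch; `trainingDF` and `inputDF` are hypothetical DataFrames with the referenced vector columns):

```scala
import org.apache.spark.ml.clustering.GaussianMixture

// A fitted model can now be pointed at differently named input/output columns
// without refitting, using the setters added to the model classes.
val model = new GaussianMixture().setK(2).fit(trainingDF)

val scored = model
  .setFeaturesCol("scaledFeatures")
  .setPredictionCol("cluster")
  .setProbabilityCol("clusterProbabilities")
  .transform(inputDF)
```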
--- .../apache/spark/ml/clustering/BisectingKMeans.scala | 8 ++++++++ .../apache/spark/ml/clustering/GaussianMixture.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index e6ca3aedffd9..cf11ba37abb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -98,6 +98,14 @@ class BisectingKMeansModel private[ml] ( copied.setSummary(trainingSummary).setParent(this.parent) } + /** @group setParam */ + @Since("2.1.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 92d0b7d085f1..19998ca44b11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -87,6 +87,18 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") val gaussians: Array[MultivariateGaussian]) extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + /** @group setParam */ + @Since("2.1.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) From a367d5ff005884322fb8bb43a1cfa4d4bf54b31a Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Thu, 24 Nov 2016 12:07:55 -0800 Subject: [PATCH 303/381] [SPARK-18578][SQL] Full outer join in correlated subquery returns incorrect results ## What changes were proposed in this pull request? - Raise Analysis exception when correlated predicates exist in the descendant operators of either operand of a Full outer join in a subquery as well as in a FOJ operator itself - Raise Analysis exception when correlated predicates exists in a Window operator (a side effect inadvertently introduced by SPARK-17348) ## How was this patch tested? Run sql/test catalyst/test and new test cases, added to SubquerySuite, showing the reported incorrect results. Author: Nattavut Sutyanyong Closes #16005 from nsyca/FOJ-incorrect.1. 
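As a concrete example of what is now rejected (it mirrors one of the new `SubquerySuite` cases below and assumes the single-column temp views `t1`, `t2`, `t3` set up there):

```scala
// The scalar subquery contains a correlated predicate (t1.c1 = t2.c1) underneath
// a FULL OUTER JOIN, so analysis now fails with an AnalysisException instead of
// letting the query run and return incorrect results.
spark.sql("""
  |SELECT (SELECT max(1)
  |        FROM (SELECT c1 FROM t2 WHERE t1.c1 = 2 AND t1.c1 = t2.c1) t2
  |        FULL JOIN t3 ON t2.c1 = t3.c1)
  |FROM t1""".stripMargin)
```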
--- .../sql/catalyst/analysis/Analyzer.scala | 10 +++++ .../org/apache/spark/sql/SubquerySuite.scala | 45 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0155741ddbc1..1db44496e67c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1076,6 +1076,10 @@ class Analyzer( // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { + // WARNING: + // Only Filter can host correlated expressions at this time + // Anyone adding a new "case" below needs to add the call to + // "failOnOuterReference" to disallow correlated expressions in it. case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) @@ -1116,12 +1120,18 @@ class Analyzer( a } case w : Window => + failOnOuterReference(w) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) w case j @ Join(left, _, RightOuter, _) => failOnOuterReference(j) failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") j + // SPARK-18578: Do not allow any correlated predicate + // in a Full (Outer) Join operator and its descendants + case j @ Join(_, _, FullOuter, _) => + failOnOuterReferenceInSubTree(j, "a FULL OUTER JOIN") + j case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] => failOnOuterReference(j) failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index f1dd1c620e66..73a53944964f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -744,4 +744,49 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } } + // This restriction applies to + // the permutation of { LOJ, ROJ, FOJ } x { EXISTS, IN, scalar subquery } + // where correlated predicates appears in right operand of LOJ, + // or in left operand of ROJ, or in either operand of FOJ. 
+ // The test cases below cover the representatives of the patterns + test("Correlated subqueries in outer joins") { + withTempView("t1", "t2", "t3") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + + // Left outer join (LOJ) in IN subquery context + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where 1 IN (select 1 + | from t3 left outer join + | (select c1 from t2 where t1.c1 = 2) t2 + | on t2.c1 = t3.c1)""".stripMargin).collect() + } + // Right outer join (ROJ) in EXISTS subquery context + intercept[AnalysisException] { + sql( + """ + | select t1.c1 + | from t1 + | where exists (select 1 + | from (select c1 from t2 where t1.c1 = 2) t2 + | right outer join t3 + | on t2.c1 = t3.c1)""".stripMargin).collect() + } + // SPARK-18578: Full outer join (FOJ) in scalar subquery context + intercept[AnalysisException] { + sql( + """ + | select (select max(1) + | from (select c1 from t2 where t1.c1 = 2 and t1.c1=t2.c1) t2 + | full join t3 + | on t2.c1=t3.c1) + | from t1""".stripMargin).collect() + } + } + } } From f58a8aa20106ea36386db79a8a66f529a8da75c9 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Fri, 25 Nov 2016 09:10:17 +0000 Subject: [PATCH 304/381] [SPARK-18575][WEB] Keep same style: adjust the position of driver log links ## What changes were proposed in this pull request? NOT BUG, just adjust the position of driver log link to keep the same style with other executors log link. ![image](https://cloud.githubusercontent.com/assets/7402327/20590092/f8bddbb8-b25b-11e6-9aaf-3b5b3073df10.png) ## How was this patch tested? no Author: uncleGen Closes #16001 from uncleGen/SPARK-18575. --- .../spark/scheduler/cluster/YarnClusterSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index ced597bed36d..4f3d5ebf403e 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -55,8 +55,8 @@ private[spark] class YarnClusterSchedulerBackend( val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" logDebug(s"Base URL for logs: $baseUrl") driverLogs = Some(Map( - "stderr" -> s"$baseUrl/stderr?start=-4096", - "stdout" -> s"$baseUrl/stdout?start=-4096")) + "stdout" -> s"$baseUrl/stdout?start=-4096", + "stderr" -> s"$baseUrl/stderr?start=-4096")) } catch { case e: Exception => logInfo("Error while building AM log links, so AM" + From f42db0c0c1434bfcccaa70d0db55e16c4396af04 Mon Sep 17 00:00:00 2001 From: "n.fraison" Date: Fri, 25 Nov 2016 09:45:51 +0000 Subject: [PATCH 305/381] [SPARK-18119][SPARK-CORE] Namenode safemode check is only performed on one namenode which can stuck the startup of SparkHistory server ## What changes were proposed in this pull request? Instead of using the setSafeMode method that check the first namenode used the one which permitts to check only for active NNs ## How was this patch tested? manual tests Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. This commit is contributed by Criteo SA under the Apache v2 licence. Author: n.fraison Closes #15648 from ashangit/SPARK-18119. 
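In short, the fix is about which `DistributedFileSystem.setSafeMode` overload the history provider calls (sketch of the two call sites; the actual hunk follows):

```scala
// Before: the safe-mode status is taken from a single NameNode, so a standby NN
// still in safe mode could stall the history server startup.
dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET)

// After: the second argument restricts the check to active NameNodes only.
dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true)
```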
--- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ca38a4763942..8ef69b142cd1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -663,9 +663,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) false } - // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) + /* true to check only for Active NNs status */ + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true) } /** From 51b1c1551d3a7147403b9e821fcc7c8f57b4824c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 25 Nov 2016 11:27:07 +0000 Subject: [PATCH 306/381] [SPARK-3359][BUILD][DOCS] More changes to resolve javadoc 8 errors that will help unidoc/genjavadoc compatibility ## What changes were proposed in this pull request? This PR only tries to fix things that looks pretty straightforward and were fixed in other previous PRs before. This PR roughly fixes several things as below: - Fix unrecognisable class and method links in javadoc by changing it from `[[..]]` to `` `...` `` ``` [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/DataStreamReader.java:226: error: reference not found [error] * Loads text files and returns a {link DataFrame} whose schema starts with a string column named ``` - Fix an exception annotation and remove code backticks in `throws` annotation Currently, sbt unidoc with Java 8 complains as below: ``` [error] .../java/org/apache/spark/sql/streaming/StreamingQuery.java:72: error: unexpected text [error] * throws StreamingQueryException, if this query has terminated with an exception. ``` `throws` should specify the correct class name from `StreamingQueryException,` to `StreamingQueryException` without backticks. (see [JDK-8007644](https://bugs.openjdk.java.net/browse/JDK-8007644)). - Fix `[[http..]]` to ``. ```diff - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. ``` `[[http...]]` link markdown in scaladoc is unrecognisable in javadoc. - It seems class can't have `return` annotation. So, two cases of this were removed. ``` [error] .../java/org/apache/spark/mllib/regression/IsotonicRegression.java:27: error: invalid use of return [error] * return New instance of IsotonicRegression. ``` - Fix < to `<` and > to `>` according to HTML rules. - Fix `
    ` complaint - Exclude unrecognisable in javadoc, `constructor`, `todo` and `groupname`. ## How was this patch tested? Manually tested by `jekyll build` with Java 7 and 8 ``` java version "1.7.0_80" Java(TM) SE Runtime Environment (build 1.7.0_80-b15) Java HotSpot(TM) 64-Bit Server VM (build 24.80-b11, mixed mode) ``` ``` java version "1.8.0_45" Java(TM) SE Runtime Environment (build 1.8.0_45-b14) Java HotSpot(TM) 64-Bit Server VM (build 25.45-b02, mixed mode) ``` Note: this does not yet make sbt unidoc suceed with Java 8 yet but it reduces the number of errors with Java 8. Author: hyukjinkwon Closes #15999 from HyukjinKwon/SPARK-3359-errors. --- .../scala/org/apache/spark/SSLOptions.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 6 +- .../org/apache/spark/api/java/JavaRDD.scala | 10 +-- .../spark/api/java/JavaSparkContext.scala | 14 ++-- .../apache/spark/io/CompressionCodec.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 18 ++--- .../spark/security/CryptoStreamUtils.scala | 4 +- .../spark/serializer/KryoSerializer.scala | 3 +- .../storage/BlockReplicationPolicy.scala | 7 +- .../scala/org/apache/spark/ui/UIUtils.scala | 4 +- .../org/apache/spark/util/AccumulatorV2.scala | 2 +- .../org/apache/spark/util/RpcUtils.scala | 2 +- .../org/apache/spark/util/StatCounter.scala | 4 +- .../org/apache/spark/util/ThreadUtils.scala | 6 +- .../scala/org/apache/spark/util/Utils.scala | 10 +-- .../spark/util/io/ChunkedByteBuffer.scala | 2 +- .../scala/org/apache/spark/graphx/Graph.scala | 4 +- .../org/apache/spark/graphx/GraphLoader.scala | 2 +- .../spark/graphx/impl/EdgeRDDImpl.scala | 2 +- .../apache/spark/graphx/lib/PageRank.scala | 4 +- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 3 +- .../spark/graphx/lib/TriangleCount.scala | 2 +- .../distribution/MultivariateGaussian.scala | 3 +- .../scala/org/apache/spark/ml/Predictor.scala | 2 +- .../spark/ml/attribute/AttributeGroup.scala | 2 +- .../spark/ml/attribute/attributes.scala | 4 +- .../classification/LogisticRegression.scala | 74 +++++++++---------- .../MultilayerPerceptronClassifier.scala | 1 - .../spark/ml/classification/NaiveBayes.scala | 8 +- .../RandomForestClassifier.scala | 6 +- .../spark/ml/clustering/BisectingKMeans.scala | 14 ++-- .../ml/clustering/ClusteringSummary.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 6 +- .../apache/spark/ml/clustering/KMeans.scala | 8 +- .../org/apache/spark/ml/clustering/LDA.scala | 42 +++++------ .../org/apache/spark/ml/feature/DCT.scala | 3 +- .../org/apache/spark/ml/feature/MinHash.scala | 5 +- .../spark/ml/feature/MinMaxScaler.scala | 4 +- .../ml/feature/PolynomialExpansion.scala | 14 ++-- .../spark/ml/feature/RandomProjection.scala | 4 +- .../spark/ml/feature/StandardScaler.scala | 4 +- .../spark/ml/feature/StopWordsRemover.scala | 5 +- .../org/apache/spark/ml/feature/package.scala | 3 +- .../IterativelyReweightedLeastSquares.scala | 7 +- .../spark/ml/param/shared/sharedParams.scala | 12 +-- .../ml/regression/AFTSurvivalRegression.scala | 27 +++---- .../ml/regression/DecisionTreeRegressor.scala | 4 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 12 +-- .../ml/regression/LinearRegression.scala | 38 +++++----- .../ml/regression/RandomForestRegressor.scala | 5 +- .../ml/source/libsvm/LibSVMDataSource.scala | 13 ++-- .../ml/tree/impl/GradientBoostedTrees.scala | 10 +-- .../spark/ml/tree/impl/RandomForest.scala | 2 +- .../org/apache/spark/ml/tree/treeParams.scala | 6 +- .../spark/ml/tuning/CrossValidator.scala | 4 +- 
.../org/apache/spark/ml/util/ReadWrite.scala | 10 +-- .../mllib/classification/NaiveBayes.scala | 28 +++---- .../mllib/clustering/BisectingKMeans.scala | 21 +++--- .../clustering/BisectingKMeansModel.scala | 4 +- .../mllib/clustering/GaussianMixture.scala | 6 +- .../clustering/GaussianMixtureModel.scala | 2 +- .../apache/spark/mllib/clustering/LDA.scala | 24 +++--- .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- .../clustering/PowerIterationClustering.scala | 13 ++-- .../mllib/clustering/StreamingKMeans.scala | 4 +- .../mllib/evaluation/RegressionMetrics.scala | 10 ++- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 12 +-- .../apache/spark/mllib/fpm/PrefixSpan.scala | 7 +- .../linalg/distributed/BlockMatrix.scala | 20 ++--- .../linalg/distributed/CoordinateMatrix.scala | 4 +- .../linalg/distributed/IndexedRowMatrix.scala | 4 +- .../mllib/linalg/distributed/RowMatrix.scala | 2 +- .../spark/mllib/optimization/Gradient.scala | 24 +++--- .../mllib/optimization/GradientDescent.scala | 4 +- .../spark/mllib/optimization/LBFGS.scala | 7 +- .../spark/mllib/optimization/NNLS.scala | 2 +- .../spark/mllib/optimization/Updater.scala | 6 +- .../org/apache/spark/mllib/package.scala | 4 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 2 +- .../spark/mllib/recommendation/ALS.scala | 7 +- .../MatrixFactorizationModel.scala | 6 +- .../mllib/regression/IsotonicRegression.scala | 9 +-- .../stat/MultivariateOnlineSummarizer.scala | 7 +- .../apache/spark/mllib/stat/Statistics.scala | 11 +-- .../distribution/MultivariateGaussian.scala | 3 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/RandomForest.scala | 8 +- .../apache/spark/mllib/tree/model/Split.scala | 2 +- .../org/apache/spark/mllib/util/MLUtils.scala | 10 +-- .../spark/mllib/util/modelSaveLoad.scala | 2 +- pom.xml | 12 +++ project/SparkBuild.scala | 5 +- .../main/scala/org/apache/spark/sql/Row.scala | 2 +- .../aggregate/CentralMomentAgg.scala | 4 +- .../apache/spark/sql/types/BinaryType.scala | 2 +- .../apache/spark/sql/types/BooleanType.scala | 2 +- .../org/apache/spark/sql/types/ByteType.scala | 2 +- .../sql/types/CalendarIntervalType.scala | 2 +- .../org/apache/spark/sql/types/DateType.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 4 +- .../apache/spark/sql/types/DoubleType.scala | 2 +- .../apache/spark/sql/types/FloatType.scala | 2 +- .../apache/spark/sql/types/IntegerType.scala | 2 +- .../org/apache/spark/sql/types/LongType.scala | 2 +- .../org/apache/spark/sql/types/MapType.scala | 2 +- .../org/apache/spark/sql/types/NullType.scala | 2 +- .../apache/spark/sql/types/ShortType.scala | 2 +- .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/types/TimestampType.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 17 +++-- .../spark/sql/DataFrameStatFunctions.scala | 16 ++-- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 62 ++++++++-------- .../sql/execution/stat/FrequentItems.scala | 3 +- .../sql/execution/stat/StatFunctions.scala | 4 +- .../spark/sql/expressions/Aggregator.scala | 8 +- .../sql/expressions/UserDefinedFunction.scala | 2 +- .../apache/spark/sql/expressions/Window.scala | 16 ++-- .../spark/sql/expressions/WindowSpec.scala | 16 ++-- .../sql/expressions/scalalang/typed.scala | 2 +- .../apache/spark/sql/expressions/udaf.scala | 24 +++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 6 +- .../sql/streaming/DataStreamReader.scala | 20 ++--- .../sql/streaming/DataStreamWriter.scala | 8 
+- .../spark/sql/streaming/StreamingQuery.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 8 +- .../sql/util/QueryExecutionListener.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 4 +- .../spark/sql/hive/orc/OrcFileOperator.scala | 2 +- 132 files changed, 558 insertions(+), 499 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index be19179b00a4..5f14102c3c36 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -150,8 +150,8 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. * * You can optionally specify the default configuration. If you do, for each setting which is * missing in SparkConf, the corresponding setting is used from the default configuration. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index bff5a29bb60f..d7e3a1b1be48 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -405,7 +405,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * partitioning of the resulting key-value pair RDD by passing a Partitioner. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = @@ -416,7 +416,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = @@ -546,7 +546,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * resulting RDD with the existing partitioner/parallelism level. * * @note If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. 
*/ def groupByKey(): JavaPairRDD[K, JIterable[V]] = diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index ccd94f876e0b..a20d264be5af 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -103,10 +103,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) @@ -117,11 +117,11 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator * * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[RDD]]. + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -167,7 +167,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 38d347aeab8c..9481156bc93a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -238,7 +238,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}} * * then `rdd` contains * {{{ @@ -270,7 +272,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}}, * * then `rdd` contains * {{{ @@ -749,7 +753,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + * `org.apache.spark.api.java.JavaSparkContext.setLocalProperty`. 
*/ def getLocalProperty(key: String): String = sc.getLocalProperty(key) @@ -769,7 +773,7 @@ class JavaSparkContext(val sc: SparkContext) * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * The application can also use `org.apache.spark.api.java.JavaSparkContext.cancelJobGroup` * to cancel all running jobs in this group. For example, * {{{ * // In the main thread: @@ -802,7 +806,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Cancel active jobs for the specified group. See - * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + * `org.apache.spark.api.java.JavaSparkContext.setJobGroup` for more information. */ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 6ba79e506a64..2e991ce394c4 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -172,7 +172,7 @@ private final object SnappyCompressionCodec { } /** - * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index bff2b8f1d06c..8e673447581c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. + * Spark paper + * for more details on RDD internals. */ abstract class RDD[T: ClassTag]( @transient private var _sc: SparkContext, @@ -469,7 +469,7 @@ abstract class RDD[T: ClassTag]( * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator * * @note This is NOT guaranteed to provide exactly the fraction of the count @@ -675,8 +675,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. 
If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) @@ -688,8 +688,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K]( f: T => K, @@ -703,8 +703,8 @@ abstract class RDD[T: ClassTag]( * may even differ each time the resulting RDD is evaluated. * * @note This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = withScope { diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8f15f50bee81..f41fc38be208 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -46,7 +46,7 @@ private[spark] object CryptoStreamUtils extends Logging { val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** - * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. */ def createCryptoOutputStream( os: OutputStream, @@ -62,7 +62,7 @@ private[spark] object CryptoStreamUtils extends Logging { } /** - * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. + * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption. */ def createCryptoInputStream( is: InputStream, diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 19e020c968a9..7eb2da1c2748 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -43,7 +43,8 @@ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, S import org.apache.spark.util.collection.CompactBuffer /** - * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * A Spark serializer that uses the + * Kryo serialization library. 
* * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bf087af16a5b..bb8a684b4c7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -89,17 +89,18 @@ class RandomBlockReplicationPolicy prioritizedPeers } + // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage - * [[http://math.stackexchange.com/questions/178690/ - * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * minimizing space usage. Please see + * here. * * @param n total number of indices * @param m number of samples needed * @param r random number generator * @return list of m random unique indices */ + // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 57f6f2f0a9be..dbeb970c81df 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -422,8 +422,8 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like