From 4c1a8682e658a4085660d74ffef28ed98adad25d Mon Sep 17 00:00:00 2001 From: Andrea zito Date: Wed, 25 Oct 2017 10:10:24 -0700 Subject: [PATCH 01/11] [SPARK-21991][LAUNCHER] Fix race condition in LauncherServer#acceptConnections ## What changes were proposed in this pull request? This patch changes the order in which _acceptConnections_ starts the client thread and schedules the client timeout action ensuring that the latter has been scheduled before the former get a chance to cancel it. ## How was this patch tested? Due to the non-deterministic nature of the patch I wasn't able to add a new test for this issue. Author: Andrea zito Closes #19217 from nivox/SPARK-21991. (cherry picked from commit 6ea8a56ca26a7e02e6574f5f763bb91059119a80) Signed-off-by: Marcelo Vanzin --- .../apache/spark/launcher/LauncherServer.java | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..454bc7a7f924d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); + } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); } + + clientThread.start(); } } catch (IOException ioe) { if (running) { From 9ed64048a740fbcd15d2b830b1edbb728f87c423 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Oct 2017 22:15:44 +0100 Subject: [PATCH 02/11] [SPARK-22227][CORE] DiskBlockManager.getAllBlocks now tolerates temp files Prior to this commit getAllBlocks implicitly assumed that the directories managed by the DiskBlockManager contain only the files corresponding to valid block IDs. In reality, this assumption was violated during shuffle, which produces temporary files in the same directory as the resulting blocks. As a result, calls to getAllBlocks during shuffle were unreliable. The fix could be made more efficient, but this is probably good enough. `DiskBlockManagerSuite` Author: Sergei Lebedev Closes #19458 from superbobry/block-id-option. 
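For readers skimming the diff below: the fix introduces an `UnrecognizedBlockId` exception thrown by `BlockId.apply` for names it cannot parse, and `getAllBlocks` catches it and skips the offending file instead of failing the whole listing. A minimal, self-contained sketch of that parse-or-skip pattern (the parser and file names here are stand-ins, not the real `BlockId` regexes):

```scala
// Stand-in types for illustration; the real ones live in org.apache.spark.storage.
class UnrecognizedId(name: String) extends Exception(s"Failed to parse $name into a block ID")

def parseId(name: String): String =
  if (name.startsWith("rdd_") || name.startsWith("shuffle_")) name
  else throw new UnrecognizedId(name)

// Listing tolerates files that are not valid blocks (e.g. temporary shuffle output):
// a parse failure becomes None and the file is skipped rather than aborting the call.
def allBlocks(fileNames: Seq[String]): Seq[String] =
  fileNames.flatMap { n =>
    try Some(parseId(n)) catch { case _: UnrecognizedId => None }
  }

// allBlocks(Seq("rdd_0_0", "some_unmanaged_file")) == Seq("rdd_0_0")
```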
(cherry picked from commit b377ef133cdc38d49b460b2cc6ece0b5892804cc) Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/storage/BlockId.scala | 16 +++++++++++++--- .../apache/spark/storage/DiskBlockManager.scala | 11 ++++++++++- .../org/apache/spark/storage/BlockIdSuite.scala | 7 +------ .../spark/storage/DiskBlockManagerSuite.scala | 7 +++++++ 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..8c1e657ecc8e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -100,6 +101,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -109,10 +114,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -127,9 +133,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. 
*/ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd1..6bc63245f2a81 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -33,13 +33,8 @@ class BlockIdSuite extends SparkFunSuite { } test("test-bad-deserialization") { - try { - // Try to deserialize an invalid block id. + intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) From 35725f735019132377d81b7cf13a6a4fb92aecfe Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 25 Oct 2017 14:31:36 -0700 Subject: [PATCH 03/11] [SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic) ## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu Closes #19558 from WeichenXu123/fix_nb_test_seed. (cherry picked from commit 841f1d776f420424c20d99cf7110d06c73f9ca20) Signed-off-by: Joseph K. 
Bradley --- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 3a2be236f1257..7c9cf76763d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -329,6 +329,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { From d2dc175a153733dd664cd77ad56304ec40b95cff Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 25 Oct 2017 14:41:02 -0700 Subject: [PATCH 04/11] [SPARK-21991][LAUNCHER][FOLLOWUP] Fix java lint ## What changes were proposed in this pull request? Fix java lint ## How was this patch tested? Run `./dev/lint-java` Author: Andrew Ash Closes #19574 from ash211/aash/fix-java-lint. (cherry picked from commit 5433be44caecaeef45ed1fdae10b223c698a9d14) Signed-off-by: Marcelo Vanzin --- .../main/java/org/apache/spark/launcher/LauncherServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 454bc7a7f924d..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -235,7 +235,7 @@ public void run() { synchronized (clients) { clients.add(clientConnection); } - + long timeoutMs = getConnectionTimeout(); // 0 is used for testing to avoid issues with clock resolution / thread scheduling, // and force an immediate timeout. @@ -244,7 +244,7 @@ public void run() { } else { timeout.run(); } - + clientThread.start(); } } catch (IOException ioe) { From 24fe7ccbacd913c19fa40199fd5511aaf55c6bfa Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Oct 2017 20:54:36 +0900 Subject: [PATCH 05/11] [SPARK-17902][R] Revive stringsAsFactors option for collect() in SparkR ## What changes were proposed in this pull request? This PR proposes to revive `stringsAsFactors` option in collect API, which was mistakenly removed in https://github.com/apache/spark/commit/71a138cd0e0a14e8426f97877e3b52a562bbd02c. Simply, it casts `charactor` to `factor` if it meets the condition, `stringsAsFactors && is.character(vec)` in primitive type conversion. ## How was this patch tested? Unit test in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #19551 from HyukjinKwon/SPARK-17902. 
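Stepping back to the SPARK-22332 change above: the determinism comes entirely from pinning Breeze's implicit `RandBasis` before constructing the `Multinomial`, which is exactly what the one-line diff does. A minimal sketch of that idea, assuming Breeze's `RandBasis.withSeed` (the same API the patch imports); the names are illustrative and not the test suite's:

```scala
import breeze.linalg.DenseVector
import breeze.stats.distributions.{Multinomial, RandBasis}

// With a fixed RandBasis in implicit scope, every draw is reproducible, so the
// generated dataset (and hence the fitted model) is identical from run to run.
def deterministicDraws(probs: Array[Double], n: Int, seed: Int): IndexedSeq[Int] = {
  implicit val basis: RandBasis = RandBasis.withSeed(seed) // must precede Multinomial(...)
  val dist = Multinomial(DenseVector(probs))
  dist.sample(n)
}

// deterministicDraws(Array(0.2, 0.3, 0.5), 5, seed = 42) always yields the same sequence.
```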
(cherry picked from commit a83d8d5adcb4e0061e43105767242ba9770dda96) Signed-off-by: hyukjinkwon --- R/pkg/R/DataFrame.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3859fa8631b38..c0a954df2b10c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1174,6 +1174,9 @@ setMethod("collect", vec <- do.call(c, col) stopifnot(class(vec) != "list") class(vec) <- PRIMITIVE_TYPES[[colType]] + if (is.character(vec) && stringsAsFactors) { + vec <- as.factor(vec) + } df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 12d8feff2ad68..50c60fe331078 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -483,6 +483,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, From a607ddc52e933151327f9b097a453eff38fcf748 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Oct 2017 21:41:45 +0100 Subject: [PATCH 06/11] [SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields When the given closure uses some fields defined in super class, `ClosureCleaner` can't figure them and don't set it properly. Those fields will be in null values. Added test. Author: Liang-Chi Hsieh Closes #19556 from viirya/SPARK-22328. (cherry picked from commit 4f8dc6b01ea787243a38678ea8199fbb0814cffc) Signed-off-by: Wenchen Fan --- .../apache/spark/util/ClosureCleaner.scala | 73 ++++++++++++++++--- .../spark/util/ClosureCleanerSuite.scala | 72 ++++++++++++++++++ 2 files changed, 133 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 489688cb0880f..2d5d3f863daa4 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging { (seen - obj.getClass).toList } + /** Initializes the accessed fields for outer classes and their super classes. */ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() + } + } + } + + /** Sets accessed fields for given class in clone object based on given object. */ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) + } + } + + /** Clones a given object and sets accessed fields in cloned object. 
*/ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) + + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } + + clone + } + /** * Clean the given closure in place. * @@ -200,9 +248,8 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + populating accessed fields because this is the starting closure") // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later - for (cls <- outerClasses) { - accessedFields(cls) = Set[String]() - } + initAccessedFields(accessedFields, outerClasses) + // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. If transitive cleaning is enabled, this may recursively // visits methods that belong to other classes in search of transitively referenced fields. @@ -248,13 +295,8 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. - val clone = instantiateClass(cls, parent) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - field.set(clone, value) - } + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { @@ -393,8 +435,15 @@ private[util] class FieldAccessFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - ClosureCleaner.getClassReader(cl).accept( - new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 4920b7ee8bfb4..9a19baee9569e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite { test("createNullValue") { new TestCreateNullValue().run() } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + val concreteObject = new TestAbstractClass { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] = { + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1) + body(rdd) + } + } + + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ => + (n1, n2, s1, s2, d1, d2) + }.collect() + } + 
assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + val concreteObject = new TestAbstractClass2 { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2) + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) + assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + } + + test("SPARK-22328: multiple outer classes have the same parent class") { + val concreteObject = new TestAbstractClass2 { + + val innerObject = new TestAbstractClass2 { + override val n1 = 222 + override val s1 = "bbb" + } + + val innerObject2 = new TestAbstractClass2 { + override val n1 = 444 + val n3 = 333 + val s3 = "ccc" + val d3 = 3.0d + + def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = + _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1) + } + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) + assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))) + } + } } // A non-serializable class we create in closures to make sure that we aren't @@ -377,3 +434,18 @@ class TestCreateNullValue { nestedClosure() } } + +abstract class TestAbstractClass extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] +} + +abstract class TestAbstractClass2 extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d +} From 2839280adc930593c64a74892fec79dcc666d468 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:51:16 -0700 Subject: [PATCH 07/11] [SPARK-22355][SQL] Dataset.collect is not threadsafe It's possible that users create a `Dataset`, and call `collect` of this `Dataset` in many threads at the same time. Currently `Dataset#collect` just call `encoder.fromRow` to convert spark rows to objects of type T, and this encoder is per-dataset. This means `Dataset#collect` is not thread-safe, because the encoder uses a projection to output the object to a re-usable row. This PR fixes this problem, by creating a new projection when calling `Dataset#collect`, so that we have the re-usable row for each method call, instead of each Dataset. N/A Author: Wenchen Fan Closes #19577 from cloud-fan/encoder. 
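The essence of the fix is moving a stateful, buffer-reusing projection from a per-Dataset field to a per-call local. A minimal sketch of that pattern using made-up stand-ins (not Spark's internal `GenerateSafeProjection` API):

```scala
// A projection that writes its result into one reusable slot: safe for a single
// caller, unsafe if cached on an object that many threads use concurrently.
final class ReusableSlotProjection(decode: Long => String) {
  private var slot: String = _                     // the shared, mutable output buffer
  def apply(in: Long): String = { slot = decode(in); slot }
}

final class MiniDataset(data: Seq[Long], decode: Long => String) {
  // Before the fix: keeping the projection as a field meant two concurrent collect()
  // calls overwrote each other's slot. Building it inside collect() gives every call
  // (hence every thread) its own buffer.
  def collect(): Seq[String] = {
    val proj = new ReusableSlotProjection(decode)
    data.map(proj.apply)
  }
}
```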
(cherry picked from commit 5c3a1f3fad695317c2fff1243cdb9b3ceb25c317) Signed-off-by: gatorsmile --- .../scala/org/apache/spark/sql/Dataset.scala | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a775fb8ed4ed3..1acbad960f1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException @@ -195,15 +196,10 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - /** - * Encoder is used mostly as a container of serde expressions in Dataset. We build logical - * plans by these serde expressions and execute it within the query framework. However, for - * performance reasons we may want to use encoder as a function to deserialize internal rows to - * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its - * `fromRow` method later. - */ - private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + // The deserializer expression which can be used to build a projection and turn rows to objects + // of type T, after collecting rows to the driver side. + private val deserializer = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -2418,7 +2414,15 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - plan.executeToIterator().map(boundEnc.fromRow).asJava + // This projection writes output to a `InternalRow`, which means applying this projection is + // not thread-safe. Here we create the projection inside this method to make `Dataset` + // thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeToIterator().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + }.asJava } } @@ -2851,7 +2855,14 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - plan.executeCollect().map(boundEnc.fromRow) + // This projection writes output to a `InternalRow`, which means applying this projection is not + // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeCollect().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. 
+ objProj(row).get(0, null).asInstanceOf[T] + } } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { From cb54f297ae52690e6162b2bab9a3940d38ff82f2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:39:53 -0700 Subject: [PATCH 08/11] [SPARK-22356][SQL] data source table should support overlapped columns between data and partition schema This is a regression introduced by #14207. After Spark 2.1, we store the inferred schema when creating the table, to avoid inferring schema again at read path. However, there is one special case: overlapped columns between data and partition. For this case, it breaks the assumption of table schema that there is on ovelap between data and partition schema, and partition columns should be at the end. The result is, for Spark 2.1, the table scan has incorrect schema that puts partition columns at the end. For Spark 2.2, we add a check in CatalogTable to validate table schema, which fails at this case. To fix this issue, a simple and safe approach is to fallback to old behavior when overlapeed columns detected, i.e. store empty schema in metastore. new regression test Author: Wenchen Fan Closes #19579 from cloud-fan/bug2. --- .../command/createDataSourceTables.scala | 35 +++++++++++++---- .../datasources/HadoopFsRelation.scala | 25 ++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++ .../HiveExternalCatalogVersionsSuite.scala | 38 ++++++++++++++----- 4 files changed, 89 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 2d890118ae0a5..d05af89df38db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType /** * A command used to create a data source table. @@ -87,14 +88,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo } } - val newTable = table.copy( - schema = dataSource.schema, - partitionColumnNames = partitionColumnNames, - // If metastore partition management for file source tables is enabled, we start off with - // partition provider hive, but no partitions in the metastore. The user has to call - // `msck repair table` to populate the table partitions. - tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && - sessionState.conf.manageFilesourcePartitions) + val newTable = dataSource match { + // Since Spark 2.1, we store the inferred schema of data source in metastore, to avoid + // inferring the schema again at read path. However if the data source has overlapped columns + // between data and partition schema, we can't store it in metastore as it breaks the + // assumption of table schema. Here we fallback to the behavior of Spark prior to 2.1, store + // empty schema in metastore and infer it at runtime. Note that this also means the new + // scalable partitioning handling feature(introduced at Spark 2.1) is disabled in this case. 
+ case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty => + logWarning("It is not recommended to create a table with overlapped data and partition " + + "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " + + "which hurts performance. Please check your data files and remove the partition " + + "columns in it.") + table.copy(schema = new StructType(), partitionColumnNames = Nil) + + case _ => + table.copy( + schema = dataSource.schema, + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && + sessionState.conf.manageFilesourcePartitions) + + } + // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. sessionState.catalog.createTable(newTable, ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 9a08524476baa..89d8a85a9cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.{SparkSession, SQLContext} @@ -50,15 +52,22 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - val schema: StructType = { - val getColName: (StructField => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } + private def getColName(f: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField } + } + + val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c6a6efda59879..3750551d7f530 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2646,4 +2646,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { + withTempPath { path => + Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j") + .write.mode("overwrite").parquet(new File(path, 
"p=1").getCanonicalPath) + withTable("t") { + sql(s"create table t using parquet options(path='${path.getCanonicalPath}')") + // We should respect the column order in data schema. + assert(spark.table("t").columns === Array("i", "p", "j")) + checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + // The DESC TABLE should report same schema as table scan. + assert(sql("desc t").select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5f8c9d5799662..6859432c406a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to // avoid downloading Spark of different versions in each run. - private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val sparkTestingDir = new File("/tmp/test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { @@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.beforeAll() val tempPyFile = File.createTempFile("test", ".py") + // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, s""" |from pyspark.sql import SparkSession + |import os | |spark = SparkSession.builder.enableHiveSupport().getOrCreate() |version_index = spark.conf.get("spark.sql.test.version.index", None) | |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) | - |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ - | " using parquet as select 1 i") + |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index)) | |json_file = "${genDataDir("json_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) - |spark.sql("create table external_data_source_tbl_" + version_index + \\ - | "(i int) using json options (path '{}')".format(json_file)) + |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file)) | |parquet_file = "${genDataDir("parquet_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) - |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ - | "(i int) using parquet options (path '{}')".format(parquet_file)) + |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file)) | |json_file2 = "${genDataDir("json2_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) - |spark.sql("create table external_table_without_schema_" + version_index + \\ - | " using json options (path '{}')".format(json_file2)) + |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2)) + | + |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index) + 
|spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1")) + |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2)) | |spark.sql("create view v_{} as select 1 i".format(version_index)) """.stripMargin.getBytes("utf8")) + // scalastyle:on line.size.limit PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") @@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .enableHiveSupport() .getOrCreate() spark = session + import session.implicits._ testingVersions.indices.foreach { index => Seq( @@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // test permanent view checkAnswer(sql(s"select i from v_$index"), Row(1)) + + // SPARK-22356: overlapped columns between data and partition schema in data source tables + val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { + spark.sql("msck repair table " + tbl_with_col_overlap) + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,j,p")) + } else { + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } } } } From cac6506caa2f4188a82fa6e5a35ee6254beaba17 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 28 Oct 2017 18:24:18 -0700 Subject: [PATCH 09/11] [SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/17075 , to fix the bug in codegen path. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19576 from cloud-fan/bug. 
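The API shift that makes the codegen fix below possible: `Decimal.toPrecision` now returns the adjusted value or `null` on overflow instead of an `Option`, so the interpreted path and the generated Java share one calling convention (value-or-null, followed by a null check). A simplified sketch of that convention, not the real `Decimal` class:

```scala
import scala.math.BigDecimal.RoundingMode

// Simplified stand-in: overflow is signalled by null rather than None, which is
// straightforward to consume from generated Java code as well as from Scala.
final case class Dec(value: BigDecimal) {
  def toPrecision(precision: Int, scale: Int,
                  mode: RoundingMode.Value = RoundingMode.HALF_UP): Dec = {
    val rounded = value.setScale(scale, mode)
    if (rounded.precision <= precision) Dec(rounded) else null // null signals overflow
  }
}

val r = Dec(BigDecimal("5.9")).toPrecision(10, 0)                    // Dec(6)
val overflowed = Dec(BigDecimal("123.45")).toPrecision(2, 1) == null // true: needs 4 digits
```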
(cherry picked from commit 7fdacbc77bbcf98c2c045a1873e749129769dcc0) Signed-off-by: gatorsmile --- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../expressions/decimalExpressions.scala | 2 +- .../expressions/mathExpressions.scala | 10 ++---- .../org/apache/spark/sql/types/Decimal.scala | 31 ++++++++++--------- .../apache/spark/sql/types/DecimalSuite.scala | 2 +- .../apache/spark/sql/MathFunctionsSuite.scala | 12 +++++++ 7 files changed, 36 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4ebdb139fe0f..474ec592201d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,7 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - decimal.toPrecision(dataType.precision, dataType.scale).orNull + decimal.toPrecision(dataType.precision, dataType.scale) } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba009a8..16ab3359db610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String /** * Create new `Decimal` with precision and scale given in `decimalType` (if any), * returning null if it overflows or creating a new `value` and returning it if successful. 
- * */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale).orNull + value.toPrecision(decimalType.precision, decimalType.scale) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c2211ae5d594b..752dea23e1f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 1c61428c57f18..42d668958d2ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1030,7 +1030,7 @@ abstract class RoundBase(child: Expression, scale: Expression, dataType match { case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, s, mode).orNull + decimal.toPrecision(decimal.precision, s, mode) case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1062,12 +1062,8 @@ abstract class RoundBase(child: Expression, scale: Expression, val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, - java.math.BigDecimal.${modeStr})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.isNull} = true; - }""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 1f1fb51addfd8..6da4f28b12962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } - def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { - case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) - case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) - } - /** * Create new `Decimal` with given precision and scale. * - * @return `Some(decimal)` if successful or `None` if overflow would occur + * @return a non-null `Decimal` value if successful or `null` if overflow would occur. 
*/ private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + if (copy.changePrecision(precision, scale, roundMode)) copy else null } /** @@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * * @return true if successful, false if overflow would occur */ - private[sql] def changePrecision(precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value): Boolean = { + private[sql] def changePrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_FLOOR) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } def ceil: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_CEILING) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 144f3d688d402..f4cdb7058ab1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") - val copy = d.toPrecision(10, 0, mode).orNull + val copy = d.toPrecision(10, 0, mode) assert(copy !== null) assert(d.ne(copy)) assert(d === copy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index c2d08a06569bf..5be8c581e9ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with table columns") { + withTable("t") { + Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") + checkAnswer( + sql("select i, round(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, bround(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + } + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } From f973587c9d593557db2e50d1d2ebb4d2e052e174 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 29 Oct 2017 18:53:47 -0700 Subject: [PATCH 
10/11] [SPARK-22344][SPARKR] Set java.io.tmpdir for SparkR tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR sets the java.io.tmpdir for CRAN checksĀ and also disables the hsperfdata for the JVM when running CRAN checks. Together this prevents files from being left behind in `/tmp` ## How was this patch tested? Tested manually on a clean EC2 machine Author: Shivaram Venkataraman Closes #19589 from shivaram/sparkr-tmpdir-clean. (cherry picked from commit 1fe27612d7bcb8b6478a36bc16ddd4802e4ee2fc) Signed-off-by: Shivaram Venkataraman --- R/pkg/inst/tests/testthat/test_basic.R | 6 ++++-- R/pkg/tests/run-all.R | 9 +++++++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 8 +++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index de47162d5325f..823d26f12feee 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,7 +18,8 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) i <- 4 df <- createDataFrame(data.frame(dummy = 1:i)) @@ -49,7 +50,8 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) training <- suppressWarnings(createDataFrame(iris)) # gaussian family diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 0aefd8006caa4..3f432f7c44089 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -36,8 +36,17 @@ invisible(lapply(sparkRWhitelistSQLDirs, sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" +sparkRTestConfig <- list() if (identical(Sys.getenv("NOT_CRAN"), "true")) { sparkRTestMaster <- "" +} else { + # Disable hsperfdata on CRAN + old_java_opt <- Sys.getenv("_JAVA_OPTIONS") + Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) + tmpDir <- tempdir() + tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) + sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index c97ba5f9a1351..240dda38fee82 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -36,6 +36,12 @@ opts_hooks$set(eval = function(options) { } options }) +r_tmp_dir <- tempdir() +tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, + spark.executor.extraJavaOptions = tmp_arg) +old_java_opt <- Sys.getenv("_JAVA_OPTIONS") +Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ``` ## Overview @@ -57,7 +63,7 @@ We use default settings in which it runs in local mode. 
It auto downloads Spark ```{r, include=FALSE} install.spark() -sparkR.session(master = "local[1]") +sparkR.session(master = "local[1]", sparkConfig = sparkSessionConfig, enableHiveSupport = FALSE) ``` ```{r, eval=FALSE} sparkR.session() From 7f8236c93560e9fc3ff16d397c007d1d17d1f4db Mon Sep 17 00:00:00 2001 From: Jen-Ming Chung Date: Mon, 30 Oct 2017 09:09:11 +0100 Subject: [PATCH 11/11] [SPARK-22291][SQL] Conversion error when transforming array types of uuid, inet and cidr to StingType in PostgreSQL ## What changes were proposed in this pull request? This PR fixes the conversion error when transforming array types of `uuid`, `inet` and `cidr` to `StingType` in PostgreSQL. ## How was this patch tested? Added test in `PostgresIntegrationSuite`. Author: Jen-Ming Chung Closes #19604 from jmchung/SPARK-22291-FOLLOWUP. --- .../sql/jdbc/PostgresIntegrationSuite.scala | 30 +++++++++++++++++++ .../datasources/jdbc/JdbcUtils.scala | 5 ++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index a1a065a443e67..fa3889fd6b76d 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -55,6 +55,19 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + "null, null, null, null, null, " + "null, null, null, null, null, null, null)" ).executeUpdate() + + conn.prepareStatement("CREATE TABLE st_with_array (c0 uuid, c1 inet, c2 cidr," + + "c3 json, c4 jsonb, c5 uuid[], c6 inet[], c7 cidr[], c8 json[], c9 jsonb[])") + .executeUpdate() + conn.prepareStatement("INSERT INTO st_with_array VALUES ( " + + "'0a532531-cdf1-45e3-963d-5de90b6a30f1', '172.168.22.1', '192.168.100.128/25', " + + """'{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}', """ + + "ARRAY['7be8aaf8-650e-4dbb-8186-0a749840ecf2'," + + "'205f9bfc-018c-4452-a605-609c0cfad228']::uuid[], ARRAY['172.16.0.41', " + + "'172.16.0.42']::inet[], ARRAY['192.168.0.0/24', '10.1.0.0/16']::cidr[], " + + """ARRAY['{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}']::json[], """ + + """ARRAY['{"a": 1, "b": 2, "c": 3}']::jsonb[])""" + ).executeUpdate() } test("Type mapping for various types") { @@ -126,4 +139,21 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(0).dataType == FloatType) assert(schema(1).dataType == ShortType) } + + test("SPARK-22291: Conversion error when transforming array types of " + + "uuid, inet and cidr to StingType in PostgreSQL") { + val df = sqlContext.read.jdbc(jdbcUrl, "st_with_array", new Properties) + val rows = df.collect() + assert(rows(0).getString(0) == "0a532531-cdf1-45e3-963d-5de90b6a30f1") + assert(rows(0).getString(1) == "172.168.22.1") + assert(rows(0).getString(2) == "192.168.100.128/25") + assert(rows(0).getString(3) == "{\"a\": \"foo\", \"b\": \"bar\"}") + assert(rows(0).getString(4) == "{\"a\": 1, \"b\": 2}") + assert(rows(0).getSeq(5) == Seq("7be8aaf8-650e-4dbb-8186-0a749840ecf2", + "205f9bfc-018c-4452-a605-609c0cfad228")) + assert(rows(0).getSeq(6) == Seq("172.16.0.41", "172.16.0.42")) + assert(rows(0).getSeq(7) == Seq("192.168.0.0/24", "10.1.0.0/16")) + assert(rows(0).getSeq(8) == Seq("""{"a": "foo", "b": "bar"}""", """{"a": 1, "b": 2}""")) + assert(rows(0).getSeq(9) == Seq("""{"a": 1, 
"b": 2, "c": 3}""")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0183805d56257..ce0610fc09394 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -440,8 +440,9 @@ object JdbcUtils extends Logging { case StringType => (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + // some underlying types are not String such as uuid, inet, cidr, etc. + array.asInstanceOf[Array[java.lang.Object]] + .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString)) case DateType => (array: Object) =>