diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3859fa8631b38..c0a954df2b10c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1174,6 +1174,9 @@ setMethod("collect", vec <- do.call(c, col) stopifnot(class(vec) != "list") class(vec) <- PRIMITIVE_TYPES[[colType]] + if (is.character(vec) && stringsAsFactors) { + vec <- as.factor(vec) + } df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index de47162d5325f..823d26f12feee 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,7 +18,8 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) i <- 4 df <- createDataFrame(data.frame(dummy = 1:i)) @@ -49,7 +50,8 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) training <- suppressWarnings(createDataFrame(iris)) # gaussian family diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 12d8feff2ad68..50c60fe331078 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -483,6 +483,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 0aefd8006caa4..3f432f7c44089 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -36,8 +36,17 @@ invisible(lapply(sparkRWhitelistSQLDirs, sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" +sparkRTestConfig <- list() if (identical(Sys.getenv("NOT_CRAN"), "true")) { sparkRTestMaster <- "" +} else { + # Disable hsperfdata on CRAN + old_java_opt <- Sys.getenv("_JAVA_OPTIONS") + Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) + tmpDir <- tempdir() + tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) + sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index c97ba5f9a1351..240dda38fee82 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -36,6 +36,12 @@ opts_hooks$set(eval = function(options) { } options }) +r_tmp_dir <- tempdir() +tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, + spark.executor.extraJavaOptions = tmp_arg) +old_java_opt <- Sys.getenv("_JAVA_OPTIONS") +Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ``` ## Overview @@ -57,7 +63,7 @@ We use default settings in which it runs in local mode. 
It auto downloads Spark ```{r, include=FALSE} install.spark() -sparkR.session(master = "local[1]") +sparkR.session(master = "local[1]", sparkConfig = sparkSessionConfig, enableHiveSupport = FALSE) ``` ```{r, eval=FALSE} sparkR.session() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..8c1e657ecc8e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -100,6 +101,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -109,10 +114,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -127,9 +133,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. */ diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 489688cb0880f..2d5d3f863daa4 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging { (seen - obj.getClass).toList } + /** Initializes the accessed fields for outer classes and their super classes. 
*/ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() + } + } + } + + /** Sets accessed fields for given class in clone object based on given object. */ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) + } + } + + /** Clones a given object and sets accessed fields in cloned object. */ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) + + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } + + clone + } + /** * Clean the given closure in place. * @@ -200,9 +248,8 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + populating accessed fields because this is the starting closure") // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later - for (cls <- outerClasses) { - accessedFields(cls) = Set[String]() - } + initAccessedFields(accessedFields, outerClasses) + // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. If transitive cleaning is enabled, this may recursively // visits methods that belong to other classes in search of transitively referenced fields. @@ -248,13 +295,8 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. 
- val clone = instantiateClass(cls, parent) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - field.set(clone, value) - } + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { @@ -393,8 +435,15 @@ private[util] class FieldAccessFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - ClosureCleaner.getClassReader(cl).accept( - new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd1..6bc63245f2a81 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -33,13 +33,8 @@ class BlockIdSuite extends SparkFunSuite { } test("test-bad-deserialization") { - try { - // Try to deserialize an invalid block id. + intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 4920b7ee8bfb4..9a19baee9569e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite { test("createNullValue") { new TestCreateNullValue().run() } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + val concreteObject = new TestAbstractClass { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] = { + withSpark(new SparkContext("local", "test")) { sc => + val rdd = 
sc.parallelize(1 to 1) + body(rdd) + } + } + + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ => + (n1, n2, s1, s2, d1, d2) + }.collect() + } + assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + val concreteObject = new TestAbstractClass2 { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2) + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) + assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + } + + test("SPARK-22328: multiple outer classes have the same parent class") { + val concreteObject = new TestAbstractClass2 { + + val innerObject = new TestAbstractClass2 { + override val n1 = 222 + override val s1 = "bbb" + } + + val innerObject2 = new TestAbstractClass2 { + override val n1 = 444 + val n3 = 333 + val s3 = "ccc" + val d3 = 3.0d + + def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = + _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1) + } + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) + assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))) + } + } } // A non-serializable class we create in closures to make sure that we aren't @@ -377,3 +434,18 @@ class TestCreateNullValue { nestedClosure() } } + +abstract class TestAbstractClass extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] +} + +abstract class TestAbstractClass2 extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index a1a065a443e67..fa3889fd6b76d 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -55,6 +55,19 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + "null, null, null, null, null, " + "null, null, null, null, null, null, null)" ).executeUpdate() + + conn.prepareStatement("CREATE TABLE st_with_array (c0 uuid, c1 inet, c2 cidr," + + "c3 json, c4 jsonb, c5 uuid[], c6 inet[], c7 cidr[], c8 json[], c9 jsonb[])") + .executeUpdate() + conn.prepareStatement("INSERT INTO st_with_array VALUES ( " + + "'0a532531-cdf1-45e3-963d-5de90b6a30f1', '172.168.22.1', '192.168.100.128/25', " + + """'{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}', """ + + "ARRAY['7be8aaf8-650e-4dbb-8186-0a749840ecf2'," + + "'205f9bfc-018c-4452-a605-609c0cfad228']::uuid[], ARRAY['172.16.0.41', " + + "'172.16.0.42']::inet[], ARRAY['192.168.0.0/24', '10.1.0.0/16']::cidr[], " + + """ARRAY['{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}']::json[], """ + + """ARRAY['{"a": 1, "b": 2, "c": 3}']::jsonb[])""" + ).executeUpdate() } test("Type mapping for various types") { @@ -126,4 +139,21 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { 
assert(schema(0).dataType == FloatType) assert(schema(1).dataType == ShortType) } + + test("SPARK-22291: Conversion error when transforming array types of " + + "uuid, inet and cidr to StringType in PostgreSQL") { + val df = sqlContext.read.jdbc(jdbcUrl, "st_with_array", new Properties) + val rows = df.collect() + assert(rows(0).getString(0) == "0a532531-cdf1-45e3-963d-5de90b6a30f1") + assert(rows(0).getString(1) == "172.168.22.1") + assert(rows(0).getString(2) == "192.168.100.128/25") + assert(rows(0).getString(3) == "{\"a\": \"foo\", \"b\": \"bar\"}") + assert(rows(0).getString(4) == "{\"a\": 1, \"b\": 2}") + assert(rows(0).getSeq(5) == Seq("7be8aaf8-650e-4dbb-8186-0a749840ecf2", + "205f9bfc-018c-4452-a605-609c0cfad228")) + assert(rows(0).getSeq(6) == Seq("172.16.0.41", "172.16.0.42")) + assert(rows(0).getSeq(7) == Seq("192.168.0.0/24", "10.1.0.0/16")) + assert(rows(0).getSeq(8) == Seq("""{"a": "foo", "b": "bar"}""", """{"a": 1, "b": 2}""")) + assert(rows(0).getSeq(9) == Seq("""{"a": 1, "b": 2, "c": 3}""")) + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout.
+ if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); + } + + clientThread.start(); } } catch (IOException ioe) { if (running) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 3a2be236f1257..7c9cf76763d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -329,6 +329,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4ebdb139fe0f..474ec592201d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,7 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - decimal.toPrecision(dataType.precision, dataType.scale).orNull + decimal.toPrecision(dataType.precision, dataType.scale) } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba009a8..16ab3359db610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String /** * Create new `Decimal` with precision and scale given in `decimalType` (if any), * returning null if it overflows or creating a new `value` and returning it if successful. 
- * */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale).orNull + value.toPrecision(decimalType.precision, decimalType.scale) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c2211ae5d594b..752dea23e1f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 1c61428c57f18..42d668958d2ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1030,7 +1030,7 @@ abstract class RoundBase(child: Expression, scale: Expression, dataType match { case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, s, mode).orNull + decimal.toPrecision(decimal.precision, s, mode) case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1062,12 +1062,8 @@ abstract class RoundBase(child: Expression, scale: Expression, val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, - java.math.BigDecimal.${modeStr})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.isNull} = true; - }""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 1f1fb51addfd8..6da4f28b12962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } - def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { - case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) - case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) - } - /** * Create new `Decimal` with given precision and scale. * - * @return `Some(decimal)` if successful or `None` if overflow would occur + * @return a non-null `Decimal` value if successful or `null` if overflow would occur. 
*/ private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + if (copy.changePrecision(precision, scale, roundMode)) copy else null } /** @@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * * @return true if successful, false if overflow would occur */ - private[sql] def changePrecision(precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value): Boolean = { + private[sql] def changePrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_FLOOR) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } def ceil: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_CEILING) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 144f3d688d402..f4cdb7058ab1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") - val copy = d.toPrecision(10, 0, mode).orNull + val copy = d.toPrecision(10, 0, mode) assert(copy !== null) assert(d.ne(copy)) assert(d === copy) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a775fb8ed4ed3..1acbad960f1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException @@ -195,15 +196,10 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - /** - * Encoder is used 
mostly as a container of serde expressions in Dataset. We build logical - * plans by these serde expressions and execute it within the query framework. However, for - * performance reasons we may want to use encoder as a function to deserialize internal rows to - * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its - * `fromRow` method later. - */ - private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + // The deserializer expression, which can be used to build a projection and turn rows into + // objects of type T, after collecting rows to the driver side. + private val deserializer = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -2418,7 +2414,15 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - plan.executeToIterator().map(boundEnc.fromRow).asJava + // This projection writes output to an `InternalRow`, which means applying this projection is + // not thread-safe. Here we create the projection inside this method to make `Dataset` + // thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeToIterator().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + }.asJava } } @@ -2851,7 +2855,14 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - plan.executeCollect().map(boundEnc.fromRow) + // This projection writes output to an `InternalRow`, which means applying this projection is not + // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeCollect().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + } } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 2d890118ae0a5..d05af89df38db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType /** * A command used to create a data source table. @@ -87,14 +88,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo } } - val newTable = table.copy( - schema = dataSource.schema, - partitionColumnNames = partitionColumnNames, - // If metastore partition management for file source tables is enabled, we start off with - // partition provider hive, but no partitions in the metastore.
The user has to call - `msck repair table` to populate the table partitions. - tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && - sessionState.conf.manageFilesourcePartitions) + val newTable = dataSource match { + // Since Spark 2.1, we store the inferred schema of a data source in the metastore, to avoid + // inferring the schema again at the read path. However, if the data source has overlapped columns + // between the data and partition schema, we can't store it in the metastore, as it breaks the + // assumption of the table schema. Here we fall back to the behavior of Spark prior to 2.1: store + // an empty schema in the metastore and infer it at runtime. Note that this also means the new + // scalable partition handling feature (introduced in Spark 2.1) is disabled in this case. + case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty => + logWarning("It is not recommended to create a table with overlapped data and partition " + + "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " + + "which hurts performance. Please check your data files and remove the partition " + + "columns from them.") + table.copy(schema = new StructType(), partitionColumnNames = Nil) + + case _ => + table.copy( + schema = dataSource.schema, + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && + sessionState.conf.manageFilesourcePartitions) + + } + // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false.
sessionState.catalog.createTable(newTable, ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 9a08524476baa..89d8a85a9cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.{SparkSession, SQLContext} @@ -50,15 +52,22 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - val schema: StructType = { - val getColName: (StructField => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } + private def getColName(f: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField } + } + + val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0183805d56257..ce0610fc09394 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -440,8 +440,9 @@ object JdbcUtils extends Logging { case StringType => (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + // some underlying types are not String such as uuid, inet, cidr, etc. 
+ array.asInstanceOf[Array[java.lang.Object]] + .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString)) case DateType => (array: Object) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index c2d08a06569bf..5be8c581e9ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with table columns") { + withTable("t") { + Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") + checkAnswer( + sql("select i, round(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, bround(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + } + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c6a6efda59879..3750551d7f530 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2646,4 +2646,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { + withTempPath { path => + Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j") + .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath) + withTable("t") { + sql(s"create table t using parquet options(path='${path.getCanonicalPath}')") + // We should respect the column order in data schema. + assert(spark.table("t").columns === Array("i", "p", "j")) + checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + // The DESC TABLE should report same schema as table scan. + assert(sql("desc t").select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5f8c9d5799662..6859432c406a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to // avoid downloading Spark of different versions in each run. 
- private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val sparkTestingDir = new File("/tmp/test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { @@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.beforeAll() val tempPyFile = File.createTempFile("test", ".py") + // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, s""" |from pyspark.sql import SparkSession + |import os | |spark = SparkSession.builder.enableHiveSupport().getOrCreate() |version_index = spark.conf.get("spark.sql.test.version.index", None) | |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) | - |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ - | " using parquet as select 1 i") + |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index)) | |json_file = "${genDataDir("json_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) - |spark.sql("create table external_data_source_tbl_" + version_index + \\ - | "(i int) using json options (path '{}')".format(json_file)) + |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file)) | |parquet_file = "${genDataDir("parquet_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) - |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ - | "(i int) using parquet options (path '{}')".format(parquet_file)) + |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file)) | |json_file2 = "${genDataDir("json2_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) - |spark.sql("create table external_table_without_schema_" + version_index + \\ - | " using json options (path '{}')".format(json_file2)) + |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2)) + | + |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index) + |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1")) + |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2)) | |spark.sql("create view v_{} as select 1 i".format(version_index)) """.stripMargin.getBytes("utf8")) + // scalastyle:on line.size.limit PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") @@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .enableHiveSupport() .getOrCreate() spark = session + import session.implicits._ testingVersions.indices.foreach { index => Seq( @@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // test permanent view checkAnswer(sql(s"select i from v_$index"), Row(1)) + + // SPARK-22356: overlapped columns between data and partition schema in data source tables + val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. 
+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { + spark.sql("msck repair table " + tbl_with_col_overlap) + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,j,p")) + } else { + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } } } }
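For readers following the `BlockId`/`DiskBlockManager` changes above: the fix boils down to parsing block file names with a dedicated exception type and skipping unparseable names, instead of letting one stray file fail the whole directory listing. Below is a minimal, self-contained Scala sketch of that pattern; it does not use Spark's real classes, and the names `BlockNameParser`, `parseBlockName`, and `ParseFailure` are illustrative stand-ins only.

```scala
import java.util.UUID

// Illustrative stand-ins for Spark's block id types (not the real API).
sealed trait Block
case class RddBlock(rddId: Int, split: Int) extends Block
case class TempShuffleBlock(uuid: UUID) extends Block

class ParseFailure(name: String)
  extends Exception(s"Failed to parse $name into a block ID")

object BlockNameParser {
  private val Rdd = "rdd_([0-9]+)_([0-9]+)".r
  private val TempShuffle = "temp_shuffle_([-A-Fa-f0-9]+)".r

  def parseBlockName(name: String): Block = name match {
    case Rdd(rddId, split) => RddBlock(rddId.toInt, split.toInt)
    case TempShuffle(uuid) => TempShuffleBlock(UUID.fromString(uuid))
    case _                 => throw new ParseFailure(name)
  }

  // Mirrors the getAllBlocks() fix: unrecognized file names are skipped
  // instead of aborting the whole listing.
  def parseAll(names: Seq[String]): Seq[Block] = names.flatMap { n =>
    try Some(parseBlockName(n)) catch { case _: ParseFailure => None }
  }
}

object Demo extends App {
  val files = Seq(
    "rdd_1_0",
    "temp_shuffle_6ba7b810-9dad-11d1-80b4-00c04fd430c8",
    "unmanaged_file")
  // Prints only the two recognized blocks; "unmanaged_file" is silently dropped.
  println(BlockNameParser.parseAll(files))
}
```

The `try`/`Some`/`None` plus `flatMap` shape is the same one used in the patched `getAllBlocks()` above, just without the Spark-specific types.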