diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e2d72a549e6b..00f9817b5397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,9 +23,9 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.sources.HadoopFsRelation @@ -128,6 +128,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + this.numBuckets = Option(numBuckets) + this.bucketColumnNames = Option(colName +: colNames) + this + } + + /** + * Sorts the output in each bucket by the given columns. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def sortBy(colName: String, colNames: String*): DataFrameWriter = { + this.sortColumnNames = Option(colName +: colNames) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -144,10 +172,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { + assertNotBucketed() ResolvedDataSource( df.sqlContext, source, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df) @@ -166,6 +196,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { + assertNotBucketed() val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -188,13 +219,47 @@ final class DataFrameWriter private[sql](df: DataFrame) { ifNotExists = false)).toRdd } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => - parCols.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => + cols.map(normalize(_, "Bucketing")) + } + + private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => + cols.map(normalize(_, "Sorting")) + } + + private def getBucketSpec: Option[BucketSpec] = { + if (sortColumnNames.isDefined) { + require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + } + + for { + n <- numBuckets + } yield { + require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) + } + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. + */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotBucketed(): Unit = { + if (numBuckets.isDefined || sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing bucketed data to this data source.") } } @@ -244,6 +309,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) @@ -372,4 +438,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var partitioningColumns: Option[Seq[String]] = None + private var bucketColumnNames: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortColumnNames: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6cf75bc17039..482130a18d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -382,13 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) - if partitionsCols.nonEmpty => + case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => + case c: CreateTableUsingAsSelect if c.temporary => val cmd = CreateTempTableUsingAsSelect( - tableIdent, provider, Array.empty[String], mode, opts, query) + c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 48eff62b297f..d8d21b06b8b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -109,6 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], + bucketSpec = None, mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 38152d0cf1a4..7a8691e7cb9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 0ca0a38f712c..ece9b8a9a917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -210,6 +210,7 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { @@ -244,6 +245,7 @@ object ResolvedDataSource extends Logging { Array(outputPath.toString), Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + bucketSpec, caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9f23d531072a..4f8524f4b967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType, StringType} import org.apache.spark.util.SerializableConfiguration @@ -121,9 +121,9 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String): OutputWriter = { + protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { try { - outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -312,19 +312,23 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] - executorSideSetup(taskContext) + private val bucketSpec = relation.bucketSpec - var outputWritersCleared = false + private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + } + + private def bucketIdExpression: Option[Expression] = for { + BucketSpec(numBuckets, _, _) <- bucketSpec + } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + // Expressions that given a partition key build a string like: col1=val/col2=val/... + private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( PartitioningUtils.escapePathName _, @@ -335,6 +339,121 @@ private[sql] class DynamicPartitionWriterContainer( val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } + } + + private def getBucketIdFromKey(key: InternalRow): Option[Int] = { + if (bucketSpec.isDefined) { + Some(key.getInt(partitionColumns.length)) + } else { + None + } + } + + private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { + val bucketIdIndex = partitionColumns.length + if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { + false + } else { + var i = partitionColumns.length - 1 + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (key1.get(i, dt) != key2.get(i, dt)) return false + i -= 1 + } + true + } + } + + private def sortBasedWrite( + sorter: UnsafeKVExternalSorter, + iterator: Iterator[InternalRow], + getSortingKey: UnsafeProjection, + getOutputRow: UnsafeProjection, + getPartitionString: UnsafeProjection, + outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { + (key1, key2) => key1 != key2 + } else { + (key1, key2) => key1 == null || !sameBucket(key1, key2) + } + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (needNewWriter(currentKey, sortedIterator.getKey)) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + /** + * Open and returns a new OutputWriter given a partition key and optional bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ + private def newOutputWriter( + key: InternalRow, + getPartitionString: UnsafeProjection): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = getBucketIdFromKey(key) + val newWriter = super.newOutputWriter(path, bucketId) + newWriter.initConverter(dataSchema) + newWriter + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + var outputWritersCleared = false + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val getSortingKey = + UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) + + val sortingKeySchema = if (bucketSpec.isEmpty) { + StructType.fromAttributes(partitionColumns) + } else { // If it's bucketed, we should also consider bucket id as part of the key. + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) + StructType(fields) + } + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) // Returns the partition path given a partition key. val getPartitionString = @@ -342,22 +461,34 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null + // If there is no sorting columns, we set sorter to null and try the hash-based writing first, + // and fill the sorter if there are too many writers and we need to fall back on sorting. + // If there are sorting columns, then we have to sort the data anyway, and no need to try the + // hash-based writing first. + var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + } else { + null + } while (iterator.hasNext && sorter == null) { val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) + // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key. + val currentKey = getSortingKey(inputRow) var currentWriter = outputWriters.get(currentKey) if (currentWriter == null) { if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) + currentWriter = newOutputWriter(currentKey, getPartitionString) outputWriters.put(currentKey.copy(), currentWriter) currentWriter.writeInternal(getOutputRow(inputRow)) } else { logInfo(s"Maximum partitions reached, falling back on sorting.") sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), + sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) @@ -369,39 +500,15 @@ private[sql] class DynamicPartitionWriterContainer( } // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. + // using external sort, or there are sorting columns and we need to sort the whole data set. if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } + sortBasedWrite( + sorter, + iterator, + getSortingKey, + getOutputRow, + getPartitionString, + outputWriters) } commitTask() @@ -412,18 +519,6 @@ private[sql] class DynamicPartitionWriterContainer( throw new SparkException("Task failed while writing rows.", cause) } - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) - val path = new Path(getWorkPath, partitionPath) - val configuration = taskAttemptContext.getConfiguration - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - def clearOutputWriters(): Unit = { if (!outputWritersCleared) { outputWriters.asScala.values.foreach(_.close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala new file mode 100644 index 000000000000..82287c896713 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.types.StructType + +/** + * A container for bucketing information. + * Bucketing is a technology for decomposing data sets into more manageable parts, and the number + * of buckets is fixed so it does not fluctuate with data. + * + * @param numBuckets number of buckets. + * @param bucketColumnNames the names of the columns that used to generate the bucket id. + * @param sortColumnNames the names of the columns that used to sort data in each bucket. + */ +private[sql] case class BucketSpec( + numBuckets: Int, + bucketColumnNames: Seq[String], + sortColumnNames: Seq[String]) + +private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider { + final override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = + // TODO: throw exception here as we won't call this method during execution, after bucketed read + // support is finished. + createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) +} + +private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { + final override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + throw new UnsupportedOperationException("use bucket version") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index aed5d0dcf2d8..0897fcadbc01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,6 +76,7 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -109,7 +110,14 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec = None, + mode, + options, + df) sqlContext.catalog.registerTable( tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 8bf538178b5d..b92edf65bfb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -34,13 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "json" @@ -49,6 +49,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( @@ -56,6 +57,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, + bucketSpec = bucketSpec, paths = paths, parameters = parameters)(sqlContext) } @@ -66,11 +68,29 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { + def this( + inputRDD: Option[RDD[String]], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + userDefinedPartitionColumns: Option[StructType], + paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String])(sqlContext: SQLContext) = { + this( + inputRDD, + maybeDataSchema, + maybePartitionSpec, + userDefinedPartitionColumns, + None, + paths, + parameters)(sqlContext) + } + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) /** Constraints to be imposed on schema to be stored. */ @@ -158,13 +178,14 @@ private[sql] class JSONRelation( partitionColumns) } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) + new JsonOutputWriter(path, bucketId, dataSchema, context) } } } @@ -172,6 +193,7 @@ private[sql] class JSONRelation( private[json] class JsonOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with Logging { @@ -188,7 +210,8 @@ private[json] class JsonOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 45f1dff96db0..4b375de05e9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -45,13 +45,13 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser -import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "parquet" @@ -60,13 +60,17 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) +private[sql] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { @@ -86,7 +90,8 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } } @@ -107,6 +112,7 @@ private[sql] class ParquetRelation( // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -123,6 +129,7 @@ private[sql] class ParquetRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -216,7 +223,7 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible @@ -276,10 +283,13 @@ private[sql] class ParquetRelation( sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, bucketId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 50ecbd35760d..d484403d1c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -165,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => + case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { + if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { + EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _) => // Get all input data source relations of the query. - val srcRelations = query.collect { + val srcRelations = c.child.collect { case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableIdent that is also being read from.") + s"Cannot overwrite table ${c.tableIdent} that is also being read from.") } else { // OK } @@ -192,7 +192,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) + + for { + spec <- c.bucketSpec + sortColumnName <- spec.sortColumnNames + sortColumn <- c.child.schema.find(_.name == sortColumnName) + } { + if (!RowOrdering.isOrderable(sortColumn.dataType)) { + failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") + } + } case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f4c7f0a26932..c35f33132f60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -161,6 +161,20 @@ trait HadoopFsRelationProvider { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation + + // TODO: expose bucket API to users. + private[sql] def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], + parameters: Map[String, String]): HadoopFsRelation = { + if (bucketSpec.isDefined) { + throw new AnalysisException("Currently we don't support bucketing for this data source.") + } + createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters) + } } /** @@ -351,7 +365,18 @@ abstract class OutputWriterFactory extends Serializable { * * @since 1.4.0 */ - def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter + + // TODO: expose bucket API to users. + private[sql] def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + newInstance(path, dataSchema, context) } /** @@ -435,6 +460,9 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ + // TODO: expose bucket API to users. + private[sql] def bucketSpec: Option[BucketSpec] = None + private class FileStatusCache { var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1616c4595221..43d84d507b20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.execution.{datasources, FileRelation} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand @@ -211,6 +211,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -240,6 +241,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get + + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) + tableProperties.put("spark.sql.sources.schema.numBucketCols", + bucketColumnNames.length.toString) + bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + } + + if (sortColumnNames.nonEmpty) { + tableProperties.put("spark.sql.sources.schema.numSortCols", + sortColumnNames.length.toString) + sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } + } + } + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover @@ -596,6 +616,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], + bucketSpec = None, mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0b4f5a0fd6ea..3687dd6f5a7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,10 +88,9 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect( - tableIdent, provider, false, partitionCols, mode, opts, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) + case c: CreateTableUsingAsSelect => + val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, + c.bucketSpec, c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 94210a5394f9..612f01cda88b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -151,6 +151,7 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], + bucketSpec = None, provider, optionsWithPath, isExternal) @@ -164,6 +165,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -254,8 +256,14 @@ case class CreateMetastoreDataSourceAsSelect( } // Create the relation based on the data of df. - val resolved = - ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec, + mode, + optionsWithPath, + df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of @@ -265,6 +273,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent, Some(resolved.relation.schema), partitionColumns, + bucketSpec, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 3538d642d523..14fa152c2331 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -37,13 +37,13 @@ import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -52,17 +52,19 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") - new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + new OrcRelation(paths, dataSchema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } private[orc] class OrcOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with HiveInspectors { @@ -101,7 +103,8 @@ private[orc] class OrcOutputWriter( val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId - val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -153,6 +156,7 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -169,6 +173,7 @@ private[sql] class OrcRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -205,7 +210,7 @@ private[sql] class OrcRelation( OrcTableScan(output, this, filters, inputPaths).execute() } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -216,12 +221,13 @@ private[sql] class OrcRelation( classOf[MapRedOutputFormat[_, _]]) } - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) + new OrcOutputWriter(path, bucketId, dataSchema, context) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e22dac3bc9e8..202851ae1366 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -707,6 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], + bucketSpec = None, provider = "json", options = Map("path" -> "just a dummy path"), isExternal = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala new file mode 100644 index 000000000000..579da0291f29 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest} + +class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("bucketed by non-existing column") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) + } + + test("numBuckets not greater than 0 or less than 100000") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt")) + intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt")) + } + + test("specify sorting columns without bucketing columns") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + } + + test("sorting by non-orderable column") { + val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) + } + + test("write bucketed data to unsupported data source") { + val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i") + intercept[AnalysisException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) + } + + test("write bucketed data to non-hive-table or existing hive table") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) + } + + private val testFileName = """.*-(\d+)$""".r + private val otherFileName = """.*-(\d+)\..*""".r + private def getBucketId(fileName: String): Int = { + fileName match { + case testFileName(bucketId) => bucketId.toInt + case otherFileName(bucketId) => bucketId.toInt + } + } + + private def testBucketing( + dataDir: File, + source: String, + bucketCols: Seq[String], + sortCols: Seq[String] = Nil): Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => + f.getName.startsWith(".") || f.getName.startsWith("_") + ) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath) + .select((bucketCols ++ sortCols).map(col): _*) + + if (sortCols.nonEmpty) { + checkAnswer(df.sort(sortCols.map(col): _*), df.collect()) + } + + val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) + } + } + } + } + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + + test("write bucketed data") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k")) + } + } + } + } + + test("write bucketed data with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k")) + } + } + } + } + + test("write bucketed data without partitionBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j")) + } + } + } + + test("write bucketed data without partitionBy with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j"), Seq("k")) + } + } + } +}