From 8cb24942890153bad70e003110a28d6111f0c407 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 28 Dec 2015 20:15:34 +0800 Subject: [PATCH 01/15] write bucketed table --- .../spark/sql/catalyst/expressions/misc.scala | 218 ++++++++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 24 ++ .../spark/sql/execution/SparkStrategies.scala | 7 +- .../sql/execution/datasources/DDLParser.scala | 3 + .../InsertIntoHadoopFsRelation.scala | 3 + .../datasources/ResolvedDataSource.scala | 19 +- .../datasources/WriterContainer.scala | 56 ++++- .../spark/sql/execution/datasources/ddl.scala | 14 +- .../datasources/json/JSONRelation.scala | 35 ++- .../datasources/parquet/ParquetRelation.scala | 27 ++- .../sql/execution/datasources/rules.scala | 12 +- .../datasources/text/DefaultSource.scala | 21 +- .../apache/spark/sql/sources/interfaces.scala | 13 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 21 ++ .../spark/sql/hive/HiveStrategies.scala | 7 +- .../spark/sql/hive/execution/commands.scala | 21 +- .../spark/sql/hive/orc/OrcRelation.scala | 26 ++- .../sql/hive/MetastoreDataSourcesSuite.scala | 3 + .../sql/hive/execution/SQLQuerySuite.scala | 10 + .../sql/sources/SimpleTextRelation.scala | 58 ++++- 20 files changed, 545 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 97f276d49f08..8f0950731845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -23,6 +23,8 @@ import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -176,3 +178,219 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** + * A function that calculates hash value for a group of expressions. + * + * The hash value for an expression depends on its type: + * - null: 0 + * - boolean: 1 for true, 0 for false. + * - byte, short, int: the input itself. + * - long: input XOR (input >>> 32) + * - float: java.lang.Float.floatToIntBits(input) + * - double: l = java.lang.Double.doubleToLongBits(input); l XOR (l >>> 32) + * - binary: java.util.Arrays.hashCode(input) + * - array: recursively calculate hash value for each element, and aggregate them by + * `result = result * 31 + elementHash` with an initial value `result = 0`. + * - map: recursively calculate hash value for each key-value pair, and aggregate + * them by `result += keyHash XOR valueHash`. + * - struct: similar to array, calculate hash value for each field and aggregate them. + * - other type: input.hashCode(). + * e.g. calculate hash value for string type by `UTF8String.hashCode()`. + * Finally we aggregate the hash values for each expression by `result = result * 31 + exprHash`. + * + * This hash algorithm follows hive's bucketing hash function, so that our bucketing function can + * be compatible with hive's, e.g. we can benefit from bucketing even the data source is mixed with + * hive tables. 
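 *
 * As a worked illustration of the aggregation above (assuming two child expressions with
 * values 1: Int and 2: Long):
 *   hash(1)  = 1
 *   hash(2L) = (2L ^ (2L >>> 32)).toInt = 2
 *   result   = (0 * 31 + 1) * 31 + 2 = 33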
+ */ +case class Hash(children: Seq[Expression]) extends Expression { + + override def dataType: DataType = IntegerType + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + var result = 0 + for (e <- children) { + val hashValue = computeHash(e.eval(input), e.dataType) + result = result * 31 + hashValue + } + result + } + + private def computeHash(v: Any, dataType: DataType): Int = v match { + case null => 0 + case b: Boolean => if (b) 1 else 0 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + + case array: ArrayData => + val elementType = dataType.asInstanceOf[ArrayType].elementType + var result = 0 + var i = 0 + while (i < array.numElements()) { + val hashValue = computeHash(array.get(i, elementType), elementType) + result = result * 31 + hashValue + i += 1 + } + result + + case map: MapData => + val mapType = dataType.asInstanceOf[MapType] + val keys = map.keyArray() + val values = map.valueArray() + var result = 0 + var i = 0 + while (i < map.numElements()) { + val keyHash = computeHash(keys.get(i, mapType.keyType), mapType.keyType) + val valueHash = computeHash(values.get(i, mapType.valueType), mapType.valueType) + result += keyHash ^ valueHash + i += 1 + } + result + + case row: InternalRow => + val fieldTypes = dataType.asInstanceOf[StructType].map(_.dataType) + var result = 0 + var i = 0 + while (i < row.numFields) { + val hashValue = computeHash(row.get(i, fieldTypes(i)), fieldTypes(i)) + result = result * 31 + hashValue + i += 1 + } + result + + case other => other.hashCode() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val expressions = children.map(_.gen(ctx)) + + val updateHashResult = expressions.zip(children.map(_.dataType)).map { case (expr, dataType) => + val hash = computeHash(expr.value, dataType, ctx) + s""" + if (${expr.isNull}) { + ${ev.value} *= 31; + } else { + ${hash.code} + ${ev.value} = ${ev.value} * 31 + ${hash.value}; + } + """ + }.mkString("\n") + + s""" + ${expressions.map(_.code).mkString("\n")} + final boolean ${ev.isNull} = false; + int ${ev.value} = 0; + $updateHashResult + """ + } + + private def computeHash( + input: String, + dataType: DataType, + ctx: CodeGenContext): GeneratedExpressionCode = { + def simpleHashValue(v: String) = GeneratedExpressionCode(code = "", isNull = "false", value = v) + + dataType match { + case NullType => simpleHashValue("0") + case BooleanType => simpleHashValue(s"($input ? 
1 : 0)") + case ByteType | ShortType | IntegerType | DateType => simpleHashValue(input) + case LongType | TimestampType => simpleHashValue(s"(int) ($input ^ ($input >>> 32))") + case FloatType => simpleHashValue(s"Float.floatToIntBits($input)") + case DoubleType => + val longBits = ctx.freshName("longBits") + GeneratedExpressionCode( + code = s"final long $longBits = Double.doubleToLongBits($input);", + isNull = "false", + value = s"(int) ($longBits ^ ($longBits >>> 32))" + ) + case BinaryType => simpleHashValue(s"java.util.Arrays.hashCode($input)") + + case ArrayType(et, _) => + val arrayHash = ctx.freshName("arrayHash") + val index = ctx.freshName("index") + val element = ctx.freshName("element") + val hash = computeHash(element, et, ctx) + val code = s""" + int $arrayHash = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + if ($input.isNullAt($index)) { + $arrayHash *= 31; + } else { + final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; + ${hash.code} + $arrayHash = $arrayHash * 31 + ${hash.value}; + } + } + """ + GeneratedExpressionCode(code = code, isNull = "false", value = arrayHash) + + case MapType(kt, vt, _) => + val mapHash = ctx.freshName("mapHash") + + val keys = ctx.freshName("keys") + val key = ctx.freshName("key") + val keyHash = computeHash(key, kt, ctx) + + val values = ctx.freshName("values") + val value = ctx.freshName("value") + val valueHash = computeHash(value, vt, ctx) + + val index = ctx.freshName("index") + + val code = s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + int $mapHash = 0; + for (int $index = 0; $index < $input.numElements(); $index++) { + final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; + ${keyHash.code} + if ($values.isNullAt($index)) { + $mapHash += ${keyHash.value}; + } else { + final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; + ${valueHash.code} + $mapHash += ${keyHash.value} ^ ${valueHash.value}; + } + } + """ + GeneratedExpressionCode(code = code, isNull = "false", value = mapHash) + + case StructType(fields) => + val structHash = ctx.freshName("structHash") + + val updateHashResult = fields.zipWithIndex.map { case (f, i) => + val jt = ctx.javaType(f.dataType) + val fieldValue = ctx.freshName(f.name) + val fieldHash = computeHash(fieldValue, f.dataType, ctx) + s""" + if ($input.isNullAt($i)) { + $structHash *= 31; + } else { + final $jt $fieldValue = ${ctx.getValue(input, f.dataType, i.toString)}; + ${fieldHash.code} + $structHash = $structHash * 31 + ${fieldHash.value}; + } + """ + }.mkString("\n") + + val code = s""" + int $structHash = 0; + $updateHashResult + """ + GeneratedExpressionCode(code = code, isNull = "false", value = structHash) + + case other => simpleHashValue(s"$input.hashCode()") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 03867beb7822..3a435a43dff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -129,6 +129,19 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colNames: String*): DataFrameWriter = { + this.numBuckets = Some(numBuckets) + this.bucketingColumns = Option(colNames) + this + } + + @scala.annotation.varargs + def sortBy(colNames: String*): DataFrameWriter = { + this.sortingColumns = 
Option(colNames) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -149,6 +162,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { df.sqlContext, source, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + numBuckets.getOrElse(0), + bucketingColumns.map(_.toArray).getOrElse(Array.empty[String]), + sortingColumns.map(_.toArray).getOrElse(Array.empty[String]), mode, extraOptions.toMap, df) @@ -245,6 +261,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + numBuckets.getOrElse(0), + bucketingColumns.map(_.toArray).getOrElse(Array.empty[String]), + sortingColumns.map(_.toArray).getOrElse(Array.empty[String]), mode, extraOptions.toMap, df.logicalPlan) @@ -368,4 +387,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var partitioningColumns: Option[Seq[String]] = None + private var bucketingColumns: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortingColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 183d9b65023b..358694b12d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -382,13 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) - if partitionsCols.nonEmpty => + case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => + case c: CreateTableUsingAsSelect if c.temporary => val cmd = CreateTempTableUsingAsSelect( - tableIdent, provider, Array.empty[String], mode, opts, query) + c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. 
Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index f22508b21090..29e2dae26b33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -109,6 +109,9 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], + 0, + Array.empty[String], + Array.empty[String], mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 735d52f80886..f0d8b20d0f09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -135,6 +135,9 @@ private[sql] case class InsertIntoHadoopFsRelation( relation, job, partitionOutput, + relation.numBuckets, + relation.bucketColumns.map(c => dataOutput.find(_.name == c).get), + relation.sortColumns.map(c => dataOutput.find(_.name == c).get), dataOutput, output, PartitioningUtils.DEFAULT_PARTITION_NAME, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index e02ee6cd6b90..7fa87d665e98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -142,6 +142,9 @@ object ResolvedDataSource extends Logging { paths, Some(dataSchema), maybePartitionsSchema, + 0, + Array.empty, + Array.empty, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => throw new AnalysisException(s"$className does not allow user-specified schemas.") @@ -173,7 +176,15 @@ object ResolvedDataSource extends Logging { SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) } } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) + dataSource.createRelation( + sqlContext, + paths, + None, + None, + 0, + Array.empty, + Array.empty, + caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => throw new AnalysisException( s"A schema needs to be specified when using $className.") @@ -210,6 +221,9 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, provider: String, partitionColumns: Array[String], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { @@ -244,6 +258,9 @@ object ResolvedDataSource extends Logging { Array(outputPath.toString), Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + numBuckets, + bucketColumns, + sortColumns, caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 
ad5536725889..f5627c372a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.sql.types.{IntegerType, StructType, StringType} import org.apache.spark.util.SerializableConfiguration @@ -124,9 +124,9 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String): OutputWriter = { + protected def newOutputWriter(path: String, bucketId: Option[Int]): OutputWriter = { try { - outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -252,7 +252,7 @@ private[sql] class DefaultWriterContainer( executorSideSetup(taskContext) val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set("spark.sql.sources.output.path", outputPath) - val writer = newOutputWriter(getWorkPath) + val writer = newOutputWriter(getWorkPath, None) writer.initConverter(dataSchema) var writerClosed = false @@ -310,6 +310,9 @@ private[sql] class DynamicPartitionWriterContainer( relation: HadoopFsRelation, job: Job, partitionColumns: Seq[Attribute], + numBuckets: Int, + bucketColumns: Seq[Attribute], + sortColumns: Seq[Attribute], dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], defaultPartitionName: String, @@ -323,8 +326,30 @@ private[sql] class DynamicPartitionWriterContainer( var outputWritersCleared = false - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) + // TODO: this follows hive, but can we just use pmod? + val buckNumExpr = Remainder(Abs(Hash(bucketColumns)), Literal(numBuckets)) + + val getKey = if (numBuckets == 0) { + UnsafeProjection.create(partitionColumns, inputSchema) + } else { + UnsafeProjection.create(partitionColumns ++ (buckNumExpr +: sortColumns), inputSchema) + } + + val keySchema = if (numBuckets == 0) { + StructType.fromAttributes(partitionColumns) + } else { + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType).fields ++ + StructType.fromAttributes(sortColumns).fields + StructType(fields) + } + + def getBucketId(key: InternalRow): Option[Int] = if (numBuckets > 0) { + Some(key.getInt(partitionColumns.length)) + } else { + None + } + // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -345,10 +370,18 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { // This will be filled in if we have to fall back on sorting. 
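// A minimal sketch of the composite key produced by `getKey` above when bucketing is enabled
// (illustrative only): the unsafe key row is laid out as
//   [partitionColumns..., bucketId, sortColumns...]
// which is why `getBucketId` reads the bucket id at ordinal `partitionColumns.length`. With the
// hive-compatible expression above, bucketId = Abs(Hash(bucketColumns)) % numBuckets, e.g. a
// hash value of 103 with numBuckets = 8 lands in bucket 103 % 8 = 7.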
- var sorter: UnsafeKVExternalSorter = null + var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + new UnsafeKVExternalSorter( + keySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + } else { + null + } while (iterator.hasNext && sorter == null) { val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) + val currentKey = getKey(inputRow) var currentWriter = outputWriters.get(currentKey) if (currentWriter == null) { @@ -359,7 +392,7 @@ private[sql] class DynamicPartitionWriterContainer( } else { logInfo(s"Maximum partitions reached, falling back on sorting.") sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), + keySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) @@ -375,7 +408,7 @@ private[sql] class DynamicPartitionWriterContainer( if (sorter != null) { while (iterator.hasNext) { val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) + sorter.insertKV(getKey(currentRow), getOutputRow(currentRow)) } logInfo(s"Sorting complete. Writing out partition files one at a time.") @@ -417,11 +450,12 @@ private[sql] class DynamicPartitionWriterContainer( /** Open and returns a new OutputWriter given a partition key. */ def newOutputWriter(key: InternalRow): OutputWriter = { val partitionPath = getPartitionString(key).getString(0) + val bucketId = getBucketId(key) val path = new Path(getWorkPath, partitionPath) val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString) + val newWriter = super.newOutputWriter(path.toString, bucketId) newWriter.initConverter(dataSchema) newWriter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e759c011e75d..6956b667c10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,6 +76,9 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -109,7 +112,16 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + 0, + Array.empty, + Array.empty, + mode, + options, + df) sqlContext.catalog.registerTable( tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 3e61ba35bea8..9d5bf9042b8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -51,6 +51,9 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( @@ -58,6 +61,9 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, + numBuckets, + bucketColumns, + sortColumns, paths = paths, parameters = parameters)(sqlContext) } @@ -68,11 +74,33 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + val numBuckets: Int, + val bucketColumns: Array[String], + val sortColumns: Array[String], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { + def this( + inputRDD: Option[RDD[String]], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + userDefinedPartitionColumns: Option[StructType], + paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String])(sqlContext: SQLContext) = { + this( + inputRDD, + maybeDataSchema, + maybePartitionSpec, + userDefinedPartitionColumns, + 0, + Array.empty, + Array.empty, + paths, + parameters)(sqlContext) + } + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) /** Constraints to be imposed on schema to be stored. 
*/ @@ -164,9 +192,10 @@ private[sql] class JSONRelation( new OutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) + new JsonOutputWriter(path, bucketId, dataSchema, context) } } } @@ -174,6 +203,7 @@ private[sql] class JSONRelation( private[json] class JsonOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with SparkHadoopMapRedUtil with Logging { @@ -190,7 +220,8 @@ private[json] class JsonOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1af2a394f399..37b47cdacc78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -60,13 +60,20 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + new ParquetRelation(paths, schema, None, partitionColumns, numBuckets, bucketColumns, + sortColumns, parameters)(sqlContext) } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) +private[sql] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { @@ -86,7 +93,8 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } } @@ -107,6 +115,9 @@ private[sql] class ParquetRelation( // This is for metastore conversion. 
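// Sketch of the bucketed file naming used by the output writers above (the task and job ids
// below are made-up placeholders): task attempt 3 of write job "1a2b3c" writing bucket 5
// produces a file named
//   part-r-00003-1a2b3c-00005<extension>
// while a non-bucketed write from the same task omits the trailing "-00005".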
private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + val numBuckets: Int, + val bucketColumns: Array[String], + val sortColumns: Array[String], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -123,6 +134,9 @@ private[sql] class ParquetRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + 0, + Array.empty, + Array.empty, parameters)(sqlContext) } @@ -282,8 +296,11 @@ private[sql] class ParquetRelation( new OutputWriterFactory { override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, bucketId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1a8e7ab202dc..21830e1a8647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -165,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => + case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { + if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { + EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _) => // Get all input data source relations of the query. 
- val srcRelations = query.collect { + val srcRelations = c.child.collect { case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableIdent that is also being read from.") + s"Cannot overwrite table ${c.tableIdent} that is also being read from.") } else { // OK } @@ -192,7 +192,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 4a1cbe4c38fa..ae6d1a9f7a15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -48,9 +48,13 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation(None, partitionColumns, paths)(sqlContext) + new TextRelation( + None, partitionColumns, numBuckets, bucketColumns, sortColumns, paths)(sqlContext) } override def shortName(): String = "text" @@ -71,6 +75,9 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + val numBuckets: Int, + val bucketColumns: Array[String], + val sortColumns: Array[String], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) @@ -119,9 +126,10 @@ private[sql] class TextRelation( new OutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, bucketId, dataSchema, context) } } } @@ -137,7 +145,11 @@ private[sql] class TextRelation( } } -class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) +class TextOutputWriter( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext) extends OutputWriter with SparkHadoopMapRedUtil { @@ -150,7 +162,8 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index fc8ce6901dfc..e54d3f3c04f7 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -160,6 +160,9 @@ trait HadoopFsRelationProvider { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation } @@ -351,7 +354,11 @@ abstract class OutputWriterFactory extends Serializable { * * @since 1.4.0 */ - def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter } /** @@ -577,6 +584,10 @@ abstract class HadoopFsRelation private[sql]( final def partitionColumns: StructType = userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) + def numBuckets: Int + def bucketColumns: Array[String] + def sortColumns: Array[String] + /** * Optional user defined partition columns. * diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f099e146d1e3..043cc27a0f11 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -211,6 +211,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -240,6 +243,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + if (userSpecifiedSchema.isDefined && bucketColumns.length > 0) { + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) + tableProperties.put("spark.sql.sources.schema.numBucketCols", bucketColumns.length.toString) + bucketColumns.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + } + } + + if (userSpecifiedSchema.isDefined && sortColumns.length > 0) { + tableProperties.put("spark.sql.sources.schema.numSortCols", sortColumns.length.toString) + sortColumns.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } + } + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. 
So, we are not expecting partition columns and we will discover @@ -596,6 +614,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], + 0, + Array.empty[String], + Array.empty[String], mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d38ad9127327..6800e0ef51b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -89,10 +89,9 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect( - tableIdent, provider, false, partitionCols, mode, opts, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) + case c: CreateTableUsingAsSelect => + val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, + c.numBuckets, c.bucketColumns, c.sortColumns, c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 94210a5394f9..ecb2a09190f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -151,6 +151,9 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], + 0, + Array.empty[String], + Array.empty[String], provider, optionsWithPath, isExternal) @@ -164,6 +167,9 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -254,8 +260,16 @@ case class CreateMetastoreDataSourceAsSelect( } // Create the relation based on the data of df. 
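// For reference, a sketch of the metastore properties recorded by
// HiveMetastoreCatalog.createDataSourceTable (defined earlier in this patch) for a table
// written with .bucketBy(5, "j").sortBy("k") (column names are illustrative):
//   spark.sql.sources.schema.numBuckets    -> 5
//   spark.sql.sources.schema.numBucketCols -> 1
//   spark.sql.sources.schema.bucketCol.0   -> j
//   spark.sql.sources.schema.numSortCols   -> 1
//   spark.sql.sources.schema.sortCol.0     -> k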
- val resolved = - ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + numBuckets, + bucketColumns, + sortColumns, + mode, + optionsWithPath, + df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of @@ -265,6 +279,9 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent, Some(resolved.relation.schema), partitionColumns, + numBuckets, + bucketColumns, + sortColumns, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1136670b7a0e..0f1e01b9ee93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -54,17 +54,29 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") - new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + new OrcRelation( + paths, + dataSchema, + None, + partitionColumns, + numBuckets, + bucketColumns, + sortColumns, + parameters)(sqlContext) } } private[orc] class OrcOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { @@ -103,7 +115,8 @@ private[orc] class OrcOutputWriter( val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) val partition = taskAttemptId.getTaskID.getId - val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -155,6 +168,9 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + val numBuckets: Int, + val bucketColumns: Array[String], + val sortColumns: Array[String], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -171,6 +187,9 @@ private[sql] class OrcRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + 0, + Array.empty, + Array.empty, parameters)(sqlContext) } @@ -221,9 +240,10 @@ private[sql] class OrcRelation( new OutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) + new OrcOutputWriter(path, bucketId, dataSchema, context) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f74eb1500b98..9b34b917b840 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -707,6 +707,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], + numBuckets = 0, + bucketColumns = Array.empty[String], + sortColumns = Array.empty[String], provider = "json", options = Map("path" -> "just a dummy path"), isExternal = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da0..1f18c3aac538 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,4 +1479,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } + + test("aa") { + Seq(("a", 1, 2.3), ("b", 2, 3.4)).toDF("i", "j", "k").write + .format("orc") + .partitionBy("i") + .bucketBy(5, "j") + .sortBy("k") + .saveAsTable("tt") + sqlContext.table("tt").show() + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 01960fd2901b..b5a5e280cd42 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -41,12 +41,24 @@ class SimpleTextSource extends HadoopFsRelationProvider { paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { - new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) + new SimpleTextRelation( + paths, + schema, + partitionColumns, + 0, + bucketColumns, + sortColumns, + parameters)(sqlContext) } } -class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] { +class AppendingTextOutputFormat( + outputFile: Path, + bucketId: Option[Int]) extends TextOutputFormat[NullWritable, Text] { val numberFormat = NumberFormat.getInstance() numberFormat.setMinimumIntegerDigits(5) @@ -55,16 +67,20 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId$bucketString") } } -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { +class SimpleTextOutputWriter( + path: String, + bucketId: Option[Int], + 
context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) + new AppendingTextOutputFormat(new Path(path), bucketId).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map { v => @@ -87,6 +103,9 @@ class SimpleTextRelation( override val paths: Array[String], val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], + val numBuckets: Int, + val bucketColumns: Array[String], + val sortColumns: Array[String], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(parameters) { @@ -178,9 +197,10 @@ class SimpleTextRelation( override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) + new SimpleTextOutputWriter(path, bucketId, context) } } @@ -211,8 +231,18 @@ class CommitFailureTestSource extends HadoopFsRelationProvider { paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String]): HadoopFsRelation = { - new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) + new CommitFailureTestRelation( + paths, + schema, + partitionColumns, + numBuckets, + bucketColumns, + sortColumns, + parameters)(sqlContext) } } @@ -220,16 +250,26 @@ class CommitFailureTestRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], + numBuckets: Int, + bucketColumns: Array[String], + sortColumns: Array[String], parameters: Map[String, String])( @transient sqlContext: SQLContext) extends SimpleTextRelation( - paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + paths, + maybeDataSchema, + userDefinedPartitionColumns, + numBuckets, + bucketColumns, + sortColumns, + parameters)(sqlContext) { override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { + new SimpleTextOutputWriter(path, bucketId, context) { override def close(): Unit = { super.close() sys.error("Intentional task commitment failure for testing purpose.") From a9dc99722bfea886c6381abbd2e1e9366fcf9064 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 29 Dec 2015 22:31:50 +0800 Subject: [PATCH 02/15] code refine --- .../apache/spark/sql/DataFrameWriter.scala | 57 ++++++++++++---- .../execution/datasources/BucketSpec.scala | 23 +++++++ .../sql/execution/datasources/DDLParser.scala | 4 +- .../InsertIntoHadoopFsRelation.scala | 6 +- .../datasources/ResolvedDataSource.scala | 22 ++---- .../datasources/WriterContainer.scala | 42 +++++++----- .../spark/sql/execution/datasources/ddl.scala | 8 +-- .../datasources/json/JSONRelation.scala | 18 ++--- .../datasources/parquet/ParquetRelation.scala | 17 ++--- .../sql/execution/datasources/rules.scala | 9 ++- .../datasources/text/DefaultSource.scala | 13 ++-- .../apache/spark/sql/sources/interfaces.scala | 12 ++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 24 +++---- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../spark/sql/hive/execution/commands.scala | 
18 ++--- .../spark/sql/hive/orc/OrcRelation.scala | 24 ++----- .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 9 --- .../sql/sources/BucketedWriteSuite.scala | 67 +++++++++++++++++++ .../sql/sources/SimpleTextRelation.scala | 40 +++-------- 20 files changed, 224 insertions(+), 195 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3a435a43dff2..b7599b0c87e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.sources.HadoopFsRelation @@ -158,13 +158,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { + assertNoBucketing() ResolvedDataSource( df.sqlContext, source, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - numBuckets.getOrElse(0), - bucketingColumns.map(_.toArray).getOrElse(Array.empty[String]), - sortingColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df) @@ -183,6 +182,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { + assertNoBucketing() val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -205,13 +205,44 @@ final class DataFrameWriter private[sql](df: DataFrame) { ifNotExists = false)).toRdd } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => - parCols.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + private def normalizedBucketCols: Option[Seq[String]] = bucketingColumns.map { cols => + cols.map(normalize(_, "Bucketing")) + } + + private def normalizedSortCols: Option[Seq[String]] = sortingColumns.map { cols => + cols.map(normalize(_, "Sorting")) + } + + private def getBucketSpec: Option[BucketSpec] = { + if (numBuckets.isEmpty && sortingColumns.isDefined) { + throw new IllegalArgumentException("Specify numBuckets and bucketing columns first.") + } + if (numBuckets.isDefined && numBuckets.get <= 0) { + throw new IllegalArgumentException("numBuckets must be greater than 0.") + } + + if (numBuckets.isDefined) { + Some(BucketSpec(numBuckets.get, normalizedBucketCols.get, normalizedSortCols)) + } else { + None + } + } + + 
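// Usage sketch for the new bucketing API (table, format and column names are illustrative; as
// of this revision, bucketed output is only supported through saveAsTable, since save() and
// insertInto() call assertNoBucketing() and reject it):
//
//   df.write
//     .format("parquet")
//     .partitionBy("i")      // optional; bucket files are created inside each partition directory
//     .bucketBy(5, "j")      // hash column "j" into 5 buckets
//     .sortBy("k")           // sort rows by "k" within each bucket file
//     .saveAsTable("bucketed_table")
//
// Calling sortBy() without bucketBy(), or passing a non-positive numBuckets, is rejected by
// getBucketSpec above with an IllegalArgumentException.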
private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNoBucketing(): Unit = { + if (numBuckets.isDefined || sortingColumns.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing bucketed data to this data source.") } } @@ -261,9 +292,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - numBuckets.getOrElse(0), - bucketingColumns.map(_.toArray).getOrElse(Array.empty[String]), - sortingColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala new file mode 100644 index 000000000000..606204ea8702 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +case class BucketSpec( + numBuckets: Int, + bucketingColumns: Seq[String], + sortingColumns: Option[Seq[String]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 29e2dae26b33..32fef585a1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -109,9 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], - 0, - Array.empty[String], - Array.empty[String], + None, mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index f0d8b20d0f09..70a9a14673e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -124,7 +124,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output @@ -135,9 +135,7 @@ private[sql] case class InsertIntoHadoopFsRelation( relation, job, partitionOutput, - relation.numBuckets, - relation.bucketColumns.map(c => dataOutput.find(_.name == c).get), - relation.sortColumns.map(c => dataOutput.find(_.name == c).get), + relation.bucketSpec, dataOutput, output, PartitioningUtils.DEFAULT_PARTITION_NAME, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 7fa87d665e98..0e21d44d972f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -142,9 +142,7 @@ object ResolvedDataSource extends Logging { paths, Some(dataSchema), maybePartitionsSchema, - 0, - Array.empty, - Array.empty, + None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => throw new AnalysisException(s"$className does not allow user-specified schemas.") @@ -176,15 +174,7 @@ object ResolvedDataSource extends Logging { SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) } } - dataSource.createRelation( - sqlContext, - paths, - None, - None, - 0, - Array.empty, - Array.empty, - caseInsensitiveOptions) + dataSource.createRelation(sqlContext, paths, None, None, None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => throw new AnalysisException( s"A schema needs to be specified when using $className.") @@ -221,9 +211,7 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, provider: String, partitionColumns: Array[String], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, 
String], data: DataFrame): ResolvedDataSource = { @@ -258,9 +246,7 @@ object ResolvedDataSource extends Logging { Array(outputPath.toString), Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), - numBuckets, - bucketColumns, - sortColumns, + bucketSpec, caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f5627c372a78..c92ef1572a5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -310,9 +310,7 @@ private[sql] class DynamicPartitionWriterContainer( relation: HadoopFsRelation, job: Job, partitionColumns: Seq[Attribute], - numBuckets: Int, - bucketColumns: Seq[Attribute], - sortColumns: Seq[Attribute], + bucketSpec: Option[BucketSpec], dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], defaultPartitionName: String, @@ -320,31 +318,34 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { + private def toAttribute(columnName: String) = inputSchema.find(_.name == columnName).get + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) var outputWritersCleared = false - // TODO: this follows hive, but can we just use pmod? - val buckNumExpr = Remainder(Abs(Hash(bucketColumns)), Literal(numBuckets)) - - val getKey = if (numBuckets == 0) { + val getKey = if (bucketSpec.isEmpty) { UnsafeProjection.create(partitionColumns, inputSchema) } else { - UnsafeProjection.create(partitionColumns ++ (buckNumExpr +: sortColumns), inputSchema) + val BucketSpec(numBuckets, bucketColumns, sortColumns) = bucketSpec.get + val bucketIdExpr = Pmod(Hash(bucketColumns.map(toAttribute)), Literal(numBuckets)) + val sortingAttrs = sortColumns.map(_.map(toAttribute)).getOrElse(Nil) + UnsafeProjection.create(partitionColumns ++ (bucketIdExpr +: sortingAttrs), inputSchema) } - val keySchema = if (numBuckets == 0) { + val keySchema = if (bucketSpec.isEmpty) { StructType.fromAttributes(partitionColumns) } else { + val sortingAttrs = bucketSpec.get.sortingColumns.map(_.map(toAttribute)).getOrElse(Nil) val fields = StructType.fromAttributes(partitionColumns) .add("bucketId", IntegerType).fields ++ - StructType.fromAttributes(sortColumns).fields + StructType.fromAttributes(sortingAttrs).fields StructType(fields) } - def getBucketId(key: InternalRow): Option[Int] = if (numBuckets > 0) { + def getBucketId(key: InternalRow): Option[Int] = if (bucketSpec.isDefined) { Some(key.getInt(partitionColumns.length)) } else { None @@ -370,7 +371,8 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { // This will be filled in if we have to fall back on sorting. 
- var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + val alwaysSort = bucketSpec.isDefined && bucketSpec.get.sortingColumns.isDefined + var sorter: UnsafeKVExternalSorter = if (alwaysSort) { new UnsafeKVExternalSorter( keySchema, StructType.fromAttributes(dataColumns), @@ -449,13 +451,17 @@ private[sql] class DynamicPartitionWriterContainer( /** Open and returns a new OutputWriter given a partition key. */ def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + getWorkPath + } val bucketId = getBucketId(key) - val path = new Path(getWorkPath, partitionPath) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString, bucketId) + val newWriter = super.newOutputWriter(path, bucketId) newWriter.initConverter(dataSchema) newWriter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 6956b667c10b..21fea9e6eeb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,9 +76,7 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -116,9 +114,7 @@ case class CreateTempTableUsingAsSelect( sqlContext, provider, partitionColumns, - 0, - Array.empty, - Array.empty, + None, mode, options, df) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 9d5bf9042b8b..01d2cc6186e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -35,7 +35,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} @@ -51,9 +51,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( @@ 
-61,9 +59,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, - numBuckets, - bucketColumns, - sortColumns, + bucketSpec = bucketSpec, paths = paths, parameters = parameters)(sqlContext) } @@ -74,9 +70,7 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val numBuckets: Int, - val bucketColumns: Array[String], - val sortColumns: Array[String], + val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) @@ -94,9 +88,7 @@ private[sql] class JSONRelation( maybeDataSchema, maybePartitionSpec, userDefinedPartitionColumns, - 0, - Array.empty, - Array.empty, + None, paths, parameters)(sqlContext) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 37b47cdacc78..5620816824b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -44,7 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -60,12 +60,9 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, numBuckets, bucketColumns, - sortColumns, parameters)(sqlContext) + new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } @@ -115,9 +112,7 @@ private[sql] class ParquetRelation( // This is for metastore conversion. 
private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val numBuckets: Int, - val bucketColumns: Array[String], - val sortColumns: Array[String], + val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -134,9 +129,7 @@ private[sql] class ParquetRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), - 0, - Array.empty, - Array.empty, + None, parameters)(sqlContext) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 21830e1a8647..54e46e23ea65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -194,6 +194,13 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => PartitioningUtils.validatePartitionColumnDataTypes( c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.bucketSpec.foreach(_.sortingColumns.foreach(_.foreach { sortCol => + val dataType = c.child.schema.find(_.name == sortCol).get.dataType + if (!RowOrdering.isOrderable(dataType)) { + failAnalysis(s"Cannot use ${dataType.simpleString} for sorting column.") + } + })) + case _ => // OK } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index ae6d1a9f7a15..dc011c677f0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -48,13 +48,10 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation( - None, partitionColumns, numBuckets, bucketColumns, sortColumns, paths)(sqlContext) + new TextRelation(None, partitionColumns, bucketSpec, paths)(sqlContext) } override def shortName(): String = "text" @@ -75,9 +72,7 @@ 
class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val numBuckets: Int, - val bucketColumns: Array[String], - val sortColumns: Array[String], + val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e54d3f3c04f7..5618889a5ad7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} +import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration @@ -160,9 +160,7 @@ trait HadoopFsRelationProvider { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation } @@ -584,9 +582,7 @@ abstract class HadoopFsRelation private[sql]( final def partitionColumns: StructType = userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) - def numBuckets: Int - def bucketColumns: Array[String] - def sortColumns: Array[String] + def bucketSpec: Option[BucketSpec] /** * Optional user defined partition columns. 
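The hunk above changes HadoopFsRelationProvider.createRelation and HadoopFsRelation to carry bucketing information as a single Option[BucketSpec] instead of the separate numBuckets / bucketColumns / sortColumns members. A minimal sketch (not part of this patch) of what a provider would implement against the new signature; ExampleSource is a hypothetical name and the relation construction is stubbed out:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.sources.{HadoopFsRelation, HadoopFsRelationProvider}
import org.apache.spark.sql.types.StructType

// Hypothetical provider, shown only to illustrate the updated createRelation signature.
class ExampleSource extends HadoopFsRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      paths: Array[String],
      dataSchema: Option[StructType],
      partitionColumns: Option[StructType],
      bucketSpec: Option[BucketSpec], // replaces numBuckets / bucketColumns / sortColumns
      parameters: Map[String, String]): HadoopFsRelation = {
    // e.g. df.write.bucketBy(8, "j").sortBy("k") arrives here roughly as
    // Some(BucketSpec(numBuckets = 8, bucketingColumns = Seq("j"), sortingColumns = Some(Seq("k")))).
    // Construction of the concrete HadoopFsRelation is elided in this sketch.
    ???
  }
}
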
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 043cc27a0f11..f3df871e7540 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand @@ -211,9 +211,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -243,18 +241,20 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - if (userSpecifiedSchema.isDefined && bucketColumns.length > 0) { + if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { + val BucketSpec(numBuckets, bucketColumns, sortColumns) = bucketSpec.get + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) tableProperties.put("spark.sql.sources.schema.numBucketCols", bucketColumns.length.toString) bucketColumns.zipWithIndex.foreach { case (bucketCol, index) => tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) } - } - if (userSpecifiedSchema.isDefined && sortColumns.length > 0) { - tableProperties.put("spark.sql.sources.schema.numSortCols", sortColumns.length.toString) - sortColumns.zipWithIndex.foreach { case (sortCol, index) => - tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + if (sortColumns.isDefined) { + tableProperties.put("spark.sql.sources.schema.numSortCols", sortColumns.get.length.toString) + sortColumns.get.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } } } @@ -614,9 +614,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], - 0, - Array.empty[String], - Array.empty[String], + None, mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6800e0ef51b0..016567067463 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -91,7 +91,7 @@ private[hive] trait HiveStrategies { case c: CreateTableUsingAsSelect => val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, - c.numBuckets, c.bucketColumns, c.sortColumns, c.mode, c.options, c.child) + c.bucketSpec, c.mode, 
c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index ecb2a09190f7..bb0b81925a2d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -151,9 +151,7 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], - 0, - Array.empty[String], - Array.empty[String], + None, provider, optionsWithPath, isExternal) @@ -167,9 +165,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -264,9 +260,7 @@ case class CreateMetastoreDataSourceAsSelect( sqlContext, provider, partitionColumns, - numBuckets, - bucketColumns, - sortColumns, + bucketSpec, mode, optionsWithPath, df) @@ -279,9 +273,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent, Some(resolved.relation.schema), partitionColumns, - numBuckets, - bucketColumns, - sortColumns, + bucketSpec, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0f1e01b9ee93..c253c1d3f531 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -38,7 +38,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType @@ -54,23 +54,13 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") - new OrcRelation( - paths, - dataSchema, - None, - partitionColumns, - numBuckets, - bucketColumns, - sortColumns, - parameters)(sqlContext) + new OrcRelation(paths, dataSchema, None, partitionColumns, bucketSpec, 
parameters)(sqlContext) } } @@ -168,9 +158,7 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val numBuckets: Int, - val bucketColumns: Array[String], - val sortColumns: Array[String], + val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -187,9 +175,7 @@ private[sql] class OrcRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), - 0, - Array.empty, - Array.empty, + None, parameters)(sqlContext) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 9b34b917b840..57872f9500f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -707,9 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], - numBuckets = 0, - bucketColumns = Array.empty[String], - sortColumns = Array.empty[String], + bucketSpec = None, provider = "json", options = Map("path" -> "just a dummy path"), isExternal = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1f18c3aac538..76fb94a3374a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1480,13 +1480,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), Row("value1", "12", 3.14, "hello")) } - test("aa") { - Seq(("a", 1, 2.3), ("b", 2, 3.4)).toDF("i", "j", "k").write - .format("orc") - .partitionBy("i") - .bucketBy(5, "j") - .sortBy("k") - .saveAsTable("tt") - sqlContext.table("tt").show() - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala new file mode 100644 index 000000000000..78bb96d4ff6f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest} + +class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("bucketed by non-existing column") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) + } + + test("numBuckets not greater than 0") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt")) + } + + test("specify sorting columns without bucketing columns") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + } + + test("sorting by non-orderable column") { + val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) + } + + test("write bucketed data to non-hive-table or existing hive table") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) + } + + test("write bucketed data") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTable("bucketedTable") { + df.write.partitionBy("i").bucketBy(8, "j").saveAsTable("bucketedTable") + } + } + + test("write bucketed data without partitioning") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTable("bucketedTable") { + df.write.bucketBy(8, "i").sortBy("j").saveAsTable("bucketedTable") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index b5a5e280cd42..6bc76db1525f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} +import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext, sources} @@ -41,18 +42,9 @@ class SimpleTextSource extends HadoopFsRelationProvider { paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new SimpleTextRelation( - paths, - schema, - partitionColumns, - 0, - bucketColumns, - sortColumns, - parameters)(sqlContext) + new SimpleTextRelation(paths, schema, partitionColumns, bucketSpec, parameters)(sqlContext) } } @@ -103,9 +95,7 @@ class SimpleTextRelation( override val paths: Array[String], val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], - val numBuckets: Int, - val bucketColumns: Array[String], - val sortColumns: Array[String], + val bucketSpec: Option[BucketSpec], parameters: Map[String, 
String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(parameters) { @@ -231,18 +221,10 @@ class CommitFailureTestSource extends HadoopFsRelationProvider { paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new CommitFailureTestRelation( - paths, - schema, - partitionColumns, - numBuckets, - bucketColumns, - sortColumns, - parameters)(sqlContext) + new CommitFailureTestRelation(paths, schema, partitionColumns, bucketSpec, parameters)( + sqlContext) } } @@ -250,18 +232,14 @@ class CommitFailureTestRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], - numBuckets: Int, - bucketColumns: Array[String], - sortColumns: Array[String], + bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient sqlContext: SQLContext) extends SimpleTextRelation( paths, maybeDataSchema, userDefinedPartitionColumns, - numBuckets, - bucketColumns, - sortColumns, + bucketSpec, parameters)(sqlContext) { override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { override def newInstance( From 4c9969848fdc35b8a5afe83f80b071d9ad310636 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Dec 2015 22:58:12 +0800 Subject: [PATCH 03/15] add more tests --- .../spark/sql/catalyst/expressions/misc.scala | 216 ------------------ .../apache/spark/sql/DataFrameWriter.scala | 19 +- .../execution/datasources/BucketSpec.scala | 23 +- .../datasources/WriterContainer.scala | 169 +++++++++----- .../sql/sources/BucketedWriteSuite.scala | 137 ++++++++++- 5 files changed, 275 insertions(+), 289 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8f0950731845..3121ff512fb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -178,219 +178,3 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } - -/** - * A function that calculates hash value for a group of expressions. - * - * The hash value for an expression depends on its type: - * - null: 0 - * - boolean: 1 for true, 0 for false. - * - byte, short, int: the input itself. - * - long: input XOR (input >>> 32) - * - float: java.lang.Float.floatToIntBits(input) - * - double: l = java.lang.Double.doubleToLongBits(input); l XOR (l >>> 32) - * - binary: java.util.Arrays.hashCode(input) - * - array: recursively calculate hash value for each element, and aggregate them by - * `result = result * 31 + elementHash` with an initial value `result = 0`. - * - map: recursively calculate hash value for each key-value pair, and aggregate - * them by `result += keyHash XOR valueHash`. - * - struct: similar to array, calculate hash value for each field and aggregate them. - * - other type: input.hashCode(). - * e.g. calculate hash value for string type by `UTF8String.hashCode()`. - * Finally we aggregate the hash values for each expression by `result = result * 31 + exprHash`. - * - * This hash algorithm follows hive's bucketing hash function, so that our bucketing function can - * be compatible with hive's, e.g. 
we can benefit from bucketing even the data source is mixed with - * hive tables. - */ -case class Hash(children: Seq[Expression]) extends Expression { - - override def dataType: DataType = IntegerType - - override def foldable: Boolean = children.forall(_.foldable) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - var result = 0 - for (e <- children) { - val hashValue = computeHash(e.eval(input), e.dataType) - result = result * 31 + hashValue - } - result - } - - private def computeHash(v: Any, dataType: DataType): Int = v match { - case null => 0 - case b: Boolean => if (b) 1 else 0 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - - case array: ArrayData => - val elementType = dataType.asInstanceOf[ArrayType].elementType - var result = 0 - var i = 0 - while (i < array.numElements()) { - val hashValue = computeHash(array.get(i, elementType), elementType) - result = result * 31 + hashValue - i += 1 - } - result - - case map: MapData => - val mapType = dataType.asInstanceOf[MapType] - val keys = map.keyArray() - val values = map.valueArray() - var result = 0 - var i = 0 - while (i < map.numElements()) { - val keyHash = computeHash(keys.get(i, mapType.keyType), mapType.keyType) - val valueHash = computeHash(values.get(i, mapType.valueType), mapType.valueType) - result += keyHash ^ valueHash - i += 1 - } - result - - case row: InternalRow => - val fieldTypes = dataType.asInstanceOf[StructType].map(_.dataType) - var result = 0 - var i = 0 - while (i < row.numFields) { - val hashValue = computeHash(row.get(i, fieldTypes(i)), fieldTypes(i)) - result = result * 31 + hashValue - i += 1 - } - result - - case other => other.hashCode() - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val expressions = children.map(_.gen(ctx)) - - val updateHashResult = expressions.zip(children.map(_.dataType)).map { case (expr, dataType) => - val hash = computeHash(expr.value, dataType, ctx) - s""" - if (${expr.isNull}) { - ${ev.value} *= 31; - } else { - ${hash.code} - ${ev.value} = ${ev.value} * 31 + ${hash.value}; - } - """ - }.mkString("\n") - - s""" - ${expressions.map(_.code).mkString("\n")} - final boolean ${ev.isNull} = false; - int ${ev.value} = 0; - $updateHashResult - """ - } - - private def computeHash( - input: String, - dataType: DataType, - ctx: CodeGenContext): GeneratedExpressionCode = { - def simpleHashValue(v: String) = GeneratedExpressionCode(code = "", isNull = "false", value = v) - - dataType match { - case NullType => simpleHashValue("0") - case BooleanType => simpleHashValue(s"($input ? 
1 : 0)") - case ByteType | ShortType | IntegerType | DateType => simpleHashValue(input) - case LongType | TimestampType => simpleHashValue(s"(int) ($input ^ ($input >>> 32))") - case FloatType => simpleHashValue(s"Float.floatToIntBits($input)") - case DoubleType => - val longBits = ctx.freshName("longBits") - GeneratedExpressionCode( - code = s"final long $longBits = Double.doubleToLongBits($input);", - isNull = "false", - value = s"(int) ($longBits ^ ($longBits >>> 32))" - ) - case BinaryType => simpleHashValue(s"java.util.Arrays.hashCode($input)") - - case ArrayType(et, _) => - val arrayHash = ctx.freshName("arrayHash") - val index = ctx.freshName("index") - val element = ctx.freshName("element") - val hash = computeHash(element, et, ctx) - val code = s""" - int $arrayHash = 0; - for (int $index = 0; $index < $input.numElements(); $index++) { - if ($input.isNullAt($index)) { - $arrayHash *= 31; - } else { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${hash.code} - $arrayHash = $arrayHash * 31 + ${hash.value}; - } - } - """ - GeneratedExpressionCode(code = code, isNull = "false", value = arrayHash) - - case MapType(kt, vt, _) => - val mapHash = ctx.freshName("mapHash") - - val keys = ctx.freshName("keys") - val key = ctx.freshName("key") - val keyHash = computeHash(key, kt, ctx) - - val values = ctx.freshName("values") - val value = ctx.freshName("value") - val valueHash = computeHash(value, vt, ctx) - - val index = ctx.freshName("index") - - val code = s""" - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - int $mapHash = 0; - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - if ($values.isNullAt($index)) { - $mapHash += ${keyHash.value}; - } else { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $mapHash += ${keyHash.value} ^ ${valueHash.value}; - } - } - """ - GeneratedExpressionCode(code = code, isNull = "false", value = mapHash) - - case StructType(fields) => - val structHash = ctx.freshName("structHash") - - val updateHashResult = fields.zipWithIndex.map { case (f, i) => - val jt = ctx.javaType(f.dataType) - val fieldValue = ctx.freshName(f.name) - val fieldHash = computeHash(fieldValue, f.dataType, ctx) - s""" - if ($input.isNullAt($i)) { - $structHash *= 31; - } else { - final $jt $fieldValue = ${ctx.getValue(input, f.dataType, i.toString)}; - ${fieldHash.code} - $structHash = $structHash * 31 + ${fieldHash.value}; - } - """ - }.mkString("\n") - - val code = s""" - int $structHash = 0; - $updateHashResult - """ - GeneratedExpressionCode(code = code, isNull = "false", value = structHash) - - case other => simpleHashValue(s"$input.hashCode()") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b7599b0c87e3..ea3a479c3890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -130,15 +130,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { } @scala.annotation.varargs - def bucketBy(numBuckets: Int, colNames: String*): DataFrameWriter = { - this.numBuckets = Some(numBuckets) - this.bucketingColumns = Option(colNames) + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + this.numBuckets = 
Option(numBuckets) + this.bucketingColumns = Option(colName +: colNames) this } @scala.annotation.varargs - def sortBy(colNames: String*): DataFrameWriter = { - this.sortingColumns = Option(colNames) + def sortBy(colName: String, colNames: String*): DataFrameWriter = { + this.sortingColumns = Option(colName +: colNames) this } @@ -158,7 +158,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { - assertNoBucketing() + assertNotBucketed() ResolvedDataSource( df.sqlContext, source, @@ -182,7 +182,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { - assertNoBucketing() + assertNotBucketed() val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -219,7 +219,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { private def getBucketSpec: Option[BucketSpec] = { if (numBuckets.isEmpty && sortingColumns.isDefined) { - throw new IllegalArgumentException("Specify numBuckets and bucketing columns first.") + throw new IllegalArgumentException( + "Specify bucketing information by bucketBy when use sortBy.") } if (numBuckets.isDefined && numBuckets.get <= 0) { throw new IllegalArgumentException("numBuckets must be greater than 0.") @@ -239,7 +240,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { s"existing columns (${validColumnNames.mkString(", ")})")) } - private def assertNoBucketing(): Unit = { + private def assertNotBucketed(): Unit = { if (numBuckets.isDefined || sortingColumns.isDefined) { throw new IllegalArgumentException( "Currently we don't support writing bucketed data to this data source.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala index 606204ea8702..3f20d99305c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala @@ -17,7 +17,26 @@ package org.apache.spark.sql.execution.datasources -case class BucketSpec( +import org.apache.spark.sql.catalyst.expressions.Attribute + +private[sql] case class BucketSpec( numBuckets: Int, bucketingColumns: Seq[String], - sortingColumns: Option[Seq[String]]) + sortingColumns: Option[Seq[String]]) { + + def resolvedBucketingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { + bucketingColumns.map { col => + inputSchema.find(_.name == col).get + } + } + + def resolvedSortingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { + if (sortingColumns.isDefined) { + sortingColumns.get.map { col => + inputSchema.find(_.name == col).get + } + } else { + Nil + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index c92ef1572a5d..7be1b128158b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -318,31 +318,31 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - private def toAttribute(columnName: String) = inputSchema.find(_.name == columnName).get - def writeRows(taskContext: TaskContext, 
iterator: Iterator[InternalRow]): Unit = { val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) var outputWritersCleared = false - val getKey = if (bucketSpec.isEmpty) { - UnsafeProjection.create(partitionColumns, inputSchema) + val getKey: InternalRow => UnsafeRow = if (bucketSpec.isEmpty) { + val projection = UnsafeProjection.create(partitionColumns, inputSchema) + row => projection(row) } else { - val BucketSpec(numBuckets, bucketColumns, sortColumns) = bucketSpec.get - val bucketIdExpr = Pmod(Hash(bucketColumns.map(toAttribute)), Literal(numBuckets)) - val sortingAttrs = sortColumns.map(_.map(toAttribute)).getOrElse(Nil) - UnsafeProjection.create(partitionColumns ++ (bucketIdExpr +: sortingAttrs), inputSchema) + val bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) + val getBucketKey = UnsafeProjection.create(bucketColumns, inputSchema) + val getResultRow = UnsafeProjection.create(partitionColumns :+ Literal(-1), inputSchema) + row => { + val bucketId = math.abs(getBucketKey(row).hashCode()) % bucketSpec.get.numBuckets + val result = getResultRow(row) + result.setInt(partitionColumns.length, bucketId) + result + } } val keySchema = if (bucketSpec.isEmpty) { StructType.fromAttributes(partitionColumns) } else { - val sortingAttrs = bucketSpec.get.sortingColumns.map(_.map(toAttribute)).getOrElse(Nil) - val fields = StructType.fromAttributes(partitionColumns) - .add("bucketId", IntegerType).fields ++ - StructType.fromAttributes(sortingAttrs).fields - StructType(fields) + StructType.fromAttributes(partitionColumns).add("bucketId", IntegerType, nullable = false) } def getBucketId(key: InternalRow): Option[Int] = if (bucketSpec.isDefined) { @@ -370,75 +370,129 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // This will be filled in if we have to fall back on sorting. - val alwaysSort = bucketSpec.isDefined && bucketSpec.get.sortingColumns.isDefined - var sorter: UnsafeKVExternalSorter = if (alwaysSort) { - new UnsafeKVExternalSorter( - keySchema, + val mustSort = bucketSpec.isDefined && bucketSpec.get.sortingColumns.isDefined + // TODO: remove duplicated code. 
+ if (mustSort) { + val bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) + val sortColumns = bucketSpec.get.resolvedSortingColumns(inputSchema) + + val getSortingKey = { + val getBucketKey = UnsafeProjection.create(bucketColumns, inputSchema) + val getResultRow = UnsafeProjection.create( + (partitionColumns :+ Literal(-1)) ++ sortColumns, inputSchema) + (row: InternalRow) => { + val bucketId = math.abs(getBucketKey(row).hashCode()) % bucketSpec.get.numBuckets + val result = getResultRow(row) + result.setInt(partitionColumns.length, bucketId) + result + } + } + + val sortingKeySchema = { + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ + StructType.fromAttributes(sortColumns) + StructType(fields) + } + + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) - } else { - null - } - while (iterator.hasNext && sorter == null) { - val inputRow = iterator.next() - val currentKey = getKey(inputRow) - var currentWriter = outputWriters.get(currentKey) - - if (currentWriter == null) { - if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) - outputWriters.put(currentKey.copy(), currentWriter) - currentWriter.writeInternal(getOutputRow(inputRow)) - } else { - logInfo(s"Maximum partitions reached, falling back on sorting.") - sorter = new UnsafeKVExternalSorter( - keySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - sorter.insertKV(currentKey, getOutputRow(inputRow)) - } - } else { - currentWriter.writeInternal(getOutputRow(inputRow)) - } - } - // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. - if (sorter != null) { while (iterator.hasNext) { val currentRow = iterator.next() - sorter.insertKV(getKey(currentRow), getOutputRow(currentRow)) + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) } logInfo(s"Sorting complete. Writing out partition files one at a time.") + def sameBucket(row1: InternalRow, row2: InternalRow): Boolean = { + partitionColumns.map(_.dataType).zipWithIndex.forall { case (dt, index) => + row1.get(index, dt) == row2.get(index, dt) + } && row1.getInt(partitionColumns.length) == row2.getInt(partitionColumns.length) + } val sortedIterator = sorter.sortedIterator() var currentKey: InternalRow = null var currentWriter: OutputWriter = null try { while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { + if (currentKey == null || !sameBucket(currentKey, sortedIterator.getKey)) { if (currentWriter != null) { currentWriter.close() } currentKey = sortedIterator.getKey.copy() logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } + currentWriter = newOutputWriter(currentKey) } - currentWriter.writeInternal(sortedIterator.getValue) } } finally { if (currentWriter != null) { currentWriter.close() } } + } else { + // This will be filled in if we have to fall back on sorting. 
+ var sorter: UnsafeKVExternalSorter = null + while (iterator.hasNext && sorter == null) { + val inputRow = iterator.next() + val currentKey = getKey(inputRow) + var currentWriter = outputWriters.get(currentKey) + + if (currentWriter == null) { + if (outputWriters.size < maxOpenFiles) { + currentWriter = newOutputWriter(currentKey) + outputWriters.put(currentKey.copy(), currentWriter) + currentWriter.writeInternal(getOutputRow(inputRow)) + } else { + logInfo(s"Maximum partitions reached, falling back on sorting.") + sorter = new UnsafeKVExternalSorter( + keySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } + } else { + currentWriter.writeInternal(getOutputRow(inputRow)) + } + } + + // If the sorter is not null that means that we reached the maxFiles above and need to + // finish using external sort. + if (sorter != null) { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } } commitTask() @@ -451,13 +505,14 @@ private[sql] class DynamicPartitionWriterContainer( /** Open and returns a new OutputWriter given a partition key. 
*/ def newOutputWriter(key: InternalRow): OutputWriter = { + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) val path = if (partitionColumns.nonEmpty) { val partitionPath = getPartitionString(key).getString(0) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) new Path(getWorkPath, partitionPath).toString } else { + configuration.set("spark.sql.sources.output.path", outputPath) getWorkPath } val bucketId = getBucketId(key) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 78bb96d4ff6f..5568aaed4911 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.sources +import java.io.File + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{AnalysisException, QueryTest} @@ -51,17 +54,141 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) } + private val parquetFileName = """.*-(\d+)\..*\.parquet""".r + private def getBucketId(fileName: String): Int = { + fileName match { + case parquetFileName(bucketId) => bucketId.toInt + } + } + test("write bucketed data") { - val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") withTable("bucketedTable") { - df.write.partitionBy("i").bucketBy(8, "j").saveAsTable("bucketedTable") + df.write + .format("parquet") + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketedTable") + + val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + for (i <- 0 until 5) { + val allBucketFiles = new File(tableDir, s"i=$i").listFiles().filter(!_.isHidden) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = math.abs(row.hashCode()) % 8 + assert(actualBucketId == bucketId) + } + } + } + } + } } } - test("write bucketed data without partitioning") { - val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + test("write bucketed data with sortBy") { + val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + withTable("bucketedTable") { + df.write + .format("parquet") + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketedTable") + + val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + for (i <- 0 until 5) { + val allBucketFiles = new File(tableDir, s"i=$i").listFiles() + .filter(_.getName.startsWith("part")) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) 
<- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") + checkAnswer(df.sort("k"), df.collect()) + val rows = df.select("j").queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actuaBucketId = math.abs(row.hashCode()) % 8 + assert(actuaBucketId == bucketId) + } + } + } + } + } + } + } + + test("write bucketed data without partitionBy") { + val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") withTable("bucketedTable") { - df.write.bucketBy(8, "i").sortBy("j").saveAsTable("bucketedTable") + df.write + .format("parquet") + .bucketBy(8, "i", "j") + .saveAsTable("bucketedTable") + + val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("i", "j") + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actuaBucketId = math.abs(row.hashCode()) % 8 + assert(actuaBucketId == bucketId) + } + } + } + } + } + } + + test("write bucketed data without partitionBy with sortBy") { + val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + withTable("bucketedTable") { + df.write + .format("parquet") + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketedTable") + + val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath) + checkAnswer(df.sort("k"), df.collect()) + val rows = df.select("i", "j").queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actuaBucketId = math.abs(row.hashCode()) % 8 + assert(actuaBucketId == bucketId) + } + } + } + } } } } From d2dc9b3ce51bd84ed5f59137d2eb76c8b6bd4f9c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Dec 2015 23:09:19 +0800 Subject: [PATCH 04/15] add more comments --- .../spark/sql/catalyst/expressions/misc.scala | 2 -- .../org/apache/spark/sql/DataFrameWriter.scala | 13 +++++++++++-- .../sql/execution/datasources/WriterContainer.scala | 5 +++-- .../spark/sql/sources/BucketedWriteSuite.scala | 12 ++++++------ 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3121ff512fb1..97f276d49f08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -23,8 +23,6 @@ import 
java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ea3a479c3890..e062134259c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -119,8 +119,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. * - * This is only applicable for Parquet at the moment. - * * @since 1.4.0 */ @scala.annotation.varargs @@ -129,6 +127,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Buckets the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's bucketing scheme. + * + * @since 2.0 + */ @scala.annotation.varargs def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { this.numBuckets = Option(numBuckets) @@ -136,6 +140,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Sorts the bucketed output by the given columns. + * + * @since 2.0 + */ @scala.annotation.varargs def sortBy(colName: String, colNames: String*): DataFrameWriter = { this.sortingColumns = Option(colName +: colNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7be1b128158b..d9f7e3c01574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -327,9 +327,10 @@ private[sql] class DynamicPartitionWriterContainer( val getKey: InternalRow => UnsafeRow = if (bucketSpec.isEmpty) { val projection = UnsafeProjection.create(partitionColumns, inputSchema) row => projection(row) - } else { + } else { // If it's bucketed, we should also consider bucket id as part of the key. val bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) val getBucketKey = UnsafeProjection.create(bucketColumns, inputSchema) + // Leave an empty int slot at the last of the result row, so that we can set bucket id later. val getResultRow = UnsafeProjection.create(partitionColumns :+ Literal(-1), inputSchema) row => { val bucketId = math.abs(getBucketKey(row).hashCode()) % bucketSpec.get.numBuckets @@ -341,7 +342,7 @@ private[sql] class DynamicPartitionWriterContainer( val keySchema = if (bucketSpec.isEmpty) { StructType.fromAttributes(partitionColumns) - } else { + } else { // If it's bucketed, we should also consider bucket id as part of the key. 
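For orientation, a minimal usage sketch of the bucketBy/sortBy writer API documented in this commit, closely mirroring the patch's own test data; the table name is illustrative and a Hive-enabled sqlContext is assumed:

    import sqlContext.implicits._

    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
    df.write
      .format("parquet")
      .partitionBy("i")   // one directory per distinct value of i, as before
      .bucketBy(8, "j")   // hash column j into 8 bucket files per partition directory
      .sortBy("k")        // keep each bucket file sorted by k
      .saveAsTable("bucketed_example")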
StructType.fromAttributes(partitionColumns).add("bucketId", IntegerType, nullable = false) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 5568aaed4911..ff58a3f6479b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -120,8 +120,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actuaBucketId = math.abs(row.hashCode()) % 8 - assert(actuaBucketId == bucketId) + val actualBucketId = math.abs(row.hashCode()) % 8 + assert(actualBucketId == bucketId) } } } @@ -151,8 +151,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actuaBucketId = math.abs(row.hashCode()) % 8 - assert(actuaBucketId == bucketId) + val actualBucketId = math.abs(row.hashCode()) % 8 + assert(actualBucketId == bucketId) } } } @@ -183,8 +183,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actuaBucketId = math.abs(row.hashCode()) % 8 - assert(actuaBucketId == bucketId) + val actualBucketId = math.abs(row.hashCode()) % 8 + assert(actualBucketId == bucketId) } } } From ba2329261740b08bd1a19dc8be0ef281281b84c9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Jan 2016 19:58:18 +0800 Subject: [PATCH 05/15] address comments --- .../sql/catalyst/expressions/UnsafeRow.java | 4 ++ .../spark/sql/catalyst/expressions/misc.scala | 35 +++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 18 +++---- .../datasources/WriterContainer.scala | 50 ++++++++----------- 4 files changed, 67 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7492b88c471a..f8956a61ac95 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -562,6 +562,10 @@ public int hashCode() { return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); } + public int hashCode(int seed) { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed); + } + @Override public boolean equals(Object other) { if (other instanceof UnsafeRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d0ec99b2320d..49bd04397a50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,8 @@ import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -177,3 +179,36 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + 
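The seeded hashCode added to UnsafeRow above is what lets the writer derive a stable bucket id from the bucketing columns. A standalone illustration of the idea, using Scala's MurmurHash3 rather than Spark's unsafe hasher (names and the 42 seed are illustrative only):

    import scala.util.hashing.MurmurHash3

    def bucketIdFor(key: String, numBuckets: Int, seed: Int = 42): Int = {
      val h = MurmurHash3.stringHash(key, seed)
      // reduce the signed hash to a non-negative bucket id in [0, numBuckets)
      ((h % numBuckets) + numBuckets) % numBuckets
    }

    // rows with equal bucketing columns always land in the same bucket file
    assert(bucketIdFor("alice", 8) == bucketIdFor("alice", 8))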
+case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private lazy val unsafeProjection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + unsafeProjection(input).hashCode(seed) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = "false" + s""" + ${unsafeRow.code} + final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); + """ + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f6badaa86ae2..4af4c57e718f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -227,18 +227,16 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def getBucketSpec: Option[BucketSpec] = { - if (numBuckets.isEmpty && sortingColumns.isDefined) { - throw new IllegalArgumentException( - "Specify bucketing information by bucketBy when use sortBy.") - } - if (numBuckets.isDefined && numBuckets.get <= 0) { - throw new IllegalArgumentException("numBuckets must be greater than 0.") + if (sortingColumns.isDefined) { + require(numBuckets.isDefined, "sortBy must be used together with bucketBy") } - if (numBuckets.isDefined) { - Some(BucketSpec(numBuckets.get, normalizedBucketCols.get, normalizedSortCols)) - } else { - None + for { + n <- numBuckets + cols <- normalizedBucketCols + } yield { + require(n > 0, "Bucket number must be greater than 0.") + BucketSpec(n, cols, normalizedSortCols) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 786f84a42562..4310a0cd17e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -313,26 +313,21 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { + private def bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) + private def sortColumns = bucketSpec.get.resolvedSortingColumns(inputSchema) + private def numBuckets = bucketSpec.get.numBuckets + private def bucketIdExpr = Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) var outputWritersCleared = false - val getKey: InternalRow => UnsafeRow = if (bucketSpec.isEmpty) { - val projection = UnsafeProjection.create(partitionColumns, inputSchema) - row => projection(row) + val getKey = if (bucketSpec.isEmpty) { + UnsafeProjection.create(partitionColumns, inputSchema) } else { // If it's bucketed, we should also consider 
bucket id as part of the key. - val bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) - val getBucketKey = UnsafeProjection.create(bucketColumns, inputSchema) - // Leave an empty int slot at the last of the result row, so that we can set bucket id later. - val getResultRow = UnsafeProjection.create(partitionColumns :+ Literal(-1), inputSchema) - row => { - val bucketId = math.abs(getBucketKey(row).hashCode()) % bucketSpec.get.numBuckets - val result = getResultRow(row) - result.setInt(partitionColumns.length, bucketId) - result - } + UnsafeProjection.create(partitionColumns :+ bucketIdExpr, inputSchema) } val keySchema = if (bucketSpec.isEmpty) { @@ -372,20 +367,8 @@ private[sql] class DynamicPartitionWriterContainer( val mustSort = bucketSpec.isDefined && bucketSpec.get.sortingColumns.isDefined // TODO: remove duplicated code. if (mustSort) { - val bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) - val sortColumns = bucketSpec.get.resolvedSortingColumns(inputSchema) - - val getSortingKey = { - val getBucketKey = UnsafeProjection.create(bucketColumns, inputSchema) - val getResultRow = UnsafeProjection.create( - (partitionColumns :+ Literal(-1)) ++ sortColumns, inputSchema) - (row: InternalRow) => { - val bucketId = math.abs(getBucketKey(row).hashCode()) % bucketSpec.get.numBuckets - val result = getResultRow(row) - result.setInt(partitionColumns.length, bucketId) - result - } - } + val getSortingKey = + UnsafeProjection.create((partitionColumns :+ bucketIdExpr) ++ sortColumns, inputSchema) val sortingKeySchema = { val fields = StructType.fromAttributes(partitionColumns) @@ -408,9 +391,16 @@ private[sql] class DynamicPartitionWriterContainer( logInfo(s"Sorting complete. Writing out partition files one at a time.") def sameBucket(row1: InternalRow, row2: InternalRow): Boolean = { - partitionColumns.map(_.dataType).zipWithIndex.forall { case (dt, index) => - row1.get(index, dt) == row2.get(index, dt) - } && row1.getInt(partitionColumns.length) == row2.getInt(partitionColumns.length) + if (row1.getInt(partitionColumns.length) != row2.getInt(partitionColumns.length)) { + false + } else { + var i = partitionColumns.length - 1 + val dt = partitionColumns(i).dataType + while (i >= 0 && row1.get(i, dt) == row2.get(i, dt)) { + i -= 1 + } + i < 0 + } } val sortedIterator = sorter.sortedIterator() var currentKey: InternalRow = null From 21e0c48e83319f7319ba339deb2bffde0188583d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Jan 2016 20:12:52 +0800 Subject: [PATCH 06/15] fix typo --- .../src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4af4c57e718f..88ad806e8c66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -117,7 +117,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. + * laid out on the file system similar to Hive's partitioning schema. * * @since 1.4.0 */ @@ -129,7 +129,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Buckets the output by the given columns on the file system. 
If specified, the output is - * laid out on the file system similar to Hive's bucketing scheme. + * laid out on the file system similar to Hive's bucketing schema. * * @since 2.0 */ From e3c3728fd67aea1849c8d4d1dab3658b1efb7417 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Jan 2016 21:45:59 +0800 Subject: [PATCH 07/15] do not break existing data source API --- .../execution/datasources/BucketSpec.scala | 42 --------- .../InsertIntoHadoopFsRelation.scala | 9 +- .../datasources/ResolvedDataSource.scala | 25 ++++-- .../datasources/WriterContainer.scala | 14 ++- .../sql/execution/datasources/bucket.scala | 89 +++++++++++++++++++ .../datasources/json/JSONRelation.scala | 10 +-- .../datasources/parquet/ParquetRelation.scala | 10 +-- .../datasources/text/DefaultSource.scala | 10 +-- .../apache/spark/sql/sources/interfaces.scala | 4 - .../spark/sql/hive/orc/OrcRelation.scala | 10 +-- .../sql/sources/BucketedWriteSuite.scala | 8 +- .../sql/sources/SimpleTextRelation.scala | 16 ++-- 12 files changed, 154 insertions(+), 93 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala deleted file mode 100644 index 3f20d99305c5..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketSpec.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
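The theme of this commit is keeping the existing data source API source-compatible: the bucket-aware variants live in new private traits whose final overrides forward the old signatures. A generic sketch of that pattern, with made-up names rather than the patch's types:

    trait Provider {
      def createRelation(paths: Array[String], options: Map[String, String]): AnyRef
    }

    trait BucketedProvider extends Provider {
      // the old entry point stays valid and simply forwards with no bucket spec
      final override def createRelation(
          paths: Array[String], options: Map[String, String]): AnyRef =
        createRelation(paths, None, options)

      // the new, richer entry point that bucket-aware sources implement
      def createRelation(
          paths: Array[String],
          bucketSpec: Option[(Int, Seq[String])],
          options: Map[String, String]): AnyRef
    }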
- */ - -package org.apache.spark.sql.execution.datasources - -import org.apache.spark.sql.catalyst.expressions.Attribute - -private[sql] case class BucketSpec( - numBuckets: Int, - bucketingColumns: Seq[String], - sortingColumns: Option[Seq[String]]) { - - def resolvedBucketingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { - bucketingColumns.map { col => - inputSchema.find(_.name == col).get - } - } - - def resolvedSortingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { - if (sortingColumns.isDefined) { - sortingColumns.get.map { col => - inputSchema.find(_.name == col).get - } - } else { - Nil - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index c8217a641422..aeeefecc7973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -124,7 +124,12 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { + val bucketSpec = relation match { + case relation: BucketedHadoopFsRelation => relation.bucketSpec + case _ => None + } + + val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output @@ -135,7 +140,7 @@ private[sql] case class InsertIntoHadoopFsRelation( relation, job, partitionOutput, - relation.bucketSpec, + bucketSpec, dataOutput, output, PartitioningUtils.DEFAULT_PARTITION_NAME, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 0e21d44d972f..4aa451f48b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -142,7 +142,6 @@ object ResolvedDataSource extends Logging { paths, Some(dataSchema), maybePartitionsSchema, - None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => throw new AnalysisException(s"$className does not allow user-specified schemas.") @@ -174,7 +173,7 @@ object ResolvedDataSource extends Logging { SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) } } - dataSource.createRelation(sqlContext, paths, None, None, None, caseInsensitiveOptions) + dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => throw new AnalysisException( s"A schema needs to be specified when using $className.") @@ -241,13 +240,21 @@ object ResolvedDataSource extends Logging { val equality = columnNameEquality(caseSensitive) val dataSchema = StructType( data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), - bucketSpec, - caseInsensitiveOptions) + val r = dataSource match { + case provider: 
BucketedHadoopFsRelationProvider => provider.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + bucketSpec, + caseInsensitiveOptions) + case provider: HadoopFsRelationProvider => provider.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + caseInsensitiveOptions) + } // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 4310a0cd17e9..4840d8900d87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -123,7 +123,12 @@ private[sql] abstract class BaseWriterContainer( protected def newOutputWriter(path: String, bucketId: Option[Int]): OutputWriter = { try { - outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) + outputWriterFactory match { + case factory: BucketedOutputWriterFactory => + factory.newInstance(path, bucketId, dataSchema, taskAttemptContext) + case factory: OutputWriterFactory => + factory.newInstance(path, dataSchema, taskAttemptContext) + } } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -395,11 +400,12 @@ private[sql] class DynamicPartitionWriterContainer( false } else { var i = partitionColumns.length - 1 - val dt = partitionColumns(i).dataType - while (i >= 0 && row1.get(i, dt) == row2.get(i, dt)) { + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (row1.get(i, dt) != row2.get(i, dt)) return false i -= 1 } - i < 0 + true } } val sortedIterator = sorter.sortedIterator() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala new file mode 100644 index 000000000000..8f53272a2ce6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
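The sameBucket rewrite above replaces a per-row collection traversal with an early-exit loop. A standalone equivalent over plain sequences, shown only to illustrate the control flow (not the patch's code):

    def sameKeyPrefix(row1: IndexedSeq[Any], row2: IndexedSeq[Any], prefixLen: Int): Boolean = {
      var i = 0
      while (i < prefixLen) {
        if (row1(i) != row2(i)) return false  // bail out on the first mismatch
        i += 1
      }
      true
    }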
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.types.StructType + +private[sql] case class BucketSpec( + numBuckets: Int, + bucketingColumns: Seq[String], + sortingColumns: Option[Seq[String]]) { + + def resolvedBucketingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { + bucketingColumns.map { col => + inputSchema.find(_.name == col).get + } + } + + def resolvedSortingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { + if (sortingColumns.isDefined) { + sortingColumns.get.map { col => + inputSchema.find(_.name == col).get + } + } else { + Nil + } + } +} + +private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider { + final override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = + createRelation(sqlContext, paths, dataSchema, partitionColumns, None, parameters) + + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], + parameters: Map[String, String]): BucketedHadoopFsRelation +} + +private[sql] abstract class BucketedHadoopFsRelation( + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String]) + extends HadoopFsRelation(maybePartitionSpec, parameters) { + def this() = this(None, Map.empty[String, String]) + + def this(parameters: Map[String, String]) = this(None, parameters) + + def bucketSpec: Option[BucketSpec] +} + +private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { + final override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + newInstance(path, None, dataSchema, context) + + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index c02d9431cb0d..91e93245f0dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -33,14 +33,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "json" @@ -50,7 +50,7 @@ class DefaultSource extends HadoopFsRelationProvider with 
DataSourceRegister { dataSchema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { new JSONRelation( inputRDD = None, @@ -72,7 +72,7 @@ private[sql] class JSONRelation( override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec, parameters) { + extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) { def this( inputRDD: Option[RDD[String]], @@ -179,7 +179,7 @@ private[sql] class JSONRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 581de8a7d3cd..18346e9def97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -44,14 +44,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "parquet" @@ -61,7 +61,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc schema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } @@ -115,7 +115,7 @@ private[sql] class ParquetRelation( val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec, parameters) + extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( @@ -283,7 +283,7 @@ private[sql] class ParquetRelation( sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 8818353485bb..f2dcec90bd06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -39,7 +39,7 @@ import org.apache.spark.util.SerializableConfiguration /** * A data source for reading text files. */ -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def createRelation( sqlContext: SQLContext, @@ -47,7 +47,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { dataSchema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { dataSchema.foreach(verifySchema) new TextRelation(None, dataSchema, partitionColumns, bucketSpec, paths)(sqlContext) } @@ -75,7 +75,7 @@ private[sql] class TextRelation( override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec, parameters) { + extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) { /** Data schema is always a single column, named "value" if original Data source has no schema. */ override def dataSchema: StructType = @@ -117,7 +117,7 @@ private[sql] class TextRelation( /** Write path. */ override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8100bf6cfec6..b22569fce288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -160,7 +160,6 @@ trait HadoopFsRelationProvider { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation } @@ -354,7 +353,6 @@ abstract class OutputWriterFactory extends Serializable { */ def newInstance( path: String, - bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter } @@ -582,8 +580,6 @@ abstract class HadoopFsRelation private[sql]( final def partitionColumns: StructType = userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) - def bucketSpec: Option[BucketSpec] - /** * Optional user defined partition columns. 
* diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 89324163414e..131a1b98f997 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -36,14 +36,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -53,7 +53,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc dataSchema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") @@ -159,7 +159,7 @@ private[sql] class OrcRelation( val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec, parameters) + extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( @@ -221,7 +221,7 @@ private[sql] class OrcRelation( classOf[MapRedOutputFormat[_, _]]) } - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index ff58a3f6479b..82d006124dc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -84,7 +84,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = math.abs(row.hashCode()) % 8 + val actualBucketId = (row.hashCode() % 8 + 8) % 8 assert(actualBucketId == bucketId) } } @@ -120,7 +120,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = math.abs(row.hashCode()) % 8 + val actualBucketId = (row.hashCode() % 8 + 8) % 8 assert(actualBucketId == bucketId) } } @@ -151,7 +151,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = math.abs(row.hashCode()) % 8 + val actualBucketId = (row.hashCode() % 8 + 8) % 8 assert(actualBucketId == bucketId) } } @@ -183,7 +183,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with 
TestHiveSingle for (row <- rows) { assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = math.abs(row.hashCode()) % 8 + val actualBucketId = (row.hashCode() % 8 + 8) % 8 assert(actualBucketId == bucketId) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9f28e7881ae8..fde67ed2cb90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -28,21 +28,21 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} -import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.execution.datasources.{BucketedOutputWriterFactory, BucketedHadoopFsRelation, BucketedHadoopFsRelationProvider, BucketSpec} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. */ -class SimpleTextSource extends HadoopFsRelationProvider { +class SimpleTextSource extends BucketedHadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { new SimpleTextRelation(paths, schema, partitionColumns, bucketSpec, parameters)(sqlContext) } } @@ -97,7 +97,7 @@ class SimpleTextRelation( val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation(parameters) { + extends BucketedHadoopFsRelation(parameters) { import sqlContext.sparkContext @@ -181,7 +181,7 @@ class SimpleTextRelation( } } - override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new BucketedOutputWriterFactory { job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) override def newInstance( @@ -214,14 +214,14 @@ object SimpleTextRelation { /** * A simple example [[HadoopFsRelationProvider]]. 
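The switch in the assertions above from math.abs(hash) % 8 to (hash % 8 + 8) % 8 matches the positive-modulo convention used on the write path and sidesteps the math.abs overflow corner case; a quick standalone check:

    def absMod(h: Int, n: Int): Int = math.abs(h) % n   // disagrees for negative hashes
    def posMod(h: Int, n: Int): Int = (h % n + n) % n   // always in [0, n)

    assert(absMod(-3, 8) == 3)
    assert(posMod(-3, 8) == 5)
    // math.abs(Int.MinValue) overflows back to Int.MinValue, so absMod can go negative
    assert(absMod(Int.MinValue, 7) == -2)
    assert(posMod(Int.MinValue, 7) == 5)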
*/ -class CommitFailureTestSource extends HadoopFsRelationProvider { +class CommitFailureTestSource extends BucketedHadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): HadoopFsRelation = { + parameters: Map[String, String]): BucketedHadoopFsRelation = { new CommitFailureTestRelation(paths, schema, partitionColumns, bucketSpec, parameters)( sqlContext) } @@ -240,7 +240,7 @@ class CommitFailureTestRelation( userDefinedPartitionColumns, bucketSpec, parameters)(sqlContext) { - override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new BucketedOutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], From 70ebd69190e1ebd27362e17240b20bf60b5fdf16 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Jan 2016 22:41:15 +0800 Subject: [PATCH 08/15] debug --- .../scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 82d006124dc6..9930505ad544 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -71,6 +71,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .saveAsTable("bucketedTable") val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + logWarning(tableDir.listFiles().map(_.getAbsolutePath).mkString("\n")) for (i <- 0 until 5) { val allBucketFiles = new File(tableDir, s"i=$i").listFiles().filter(!_.isHidden) val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) From 6e3c1c0370dec30002992a3a83b2066f4d5278df Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 08:56:06 +0800 Subject: [PATCH 09/15] debug --- .../sql/sources/BucketedWriteSuite.scala | 69 +++++++++---------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 9930505ad544..dc9aea0b9ccf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -71,7 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .saveAsTable("bucketedTable") val tableDir = new File(hiveContext.warehousePath, "bucketedTable") - logWarning(tableDir.listFiles().map(_.getAbsolutePath).mkString("\n")) + logWarning(hiveContext.warehousePath.getAbsolutePath) + logWarning(hiveContext.warehousePath.listFiles().map(_.getAbsolutePath).mkString("\n")) for (i <- 0 until 5) { val allBucketFiles = new File(tableDir, s"i=$i").listFiles().filter(!_.isHidden) val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) @@ -79,15 +80,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for ((bucketId, bucketFiles) <- groupedBucketFiles) { for (bucketFile <- bucketFiles) { - withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { - val df = 
sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") - val rows = df.queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) } } } @@ -114,16 +113,14 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for ((bucketId, bucketFiles) <- groupedBucketFiles) { for (bucketFile <- bucketFiles) { - withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") - checkAnswer(df.sort("k"), df.collect()) - val rows = df.select("j").queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") + checkAnswer(df.sort("k"), df.collect()) + val rows = df.select("j").queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) } } } @@ -146,15 +143,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for ((bucketId, bucketFiles) <- groupedBucketFiles) { for (bucketFile <- bucketFiles) { - withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("i", "j") - val rows = df.queryExecution.toRdd.map(_.copy()).collect() + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("i", "j") + val rows = df.queryExecution.toRdd.map(_.copy()).collect() - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) } } } @@ -177,16 +172,14 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle for ((bucketId, bucketFiles) <- groupedBucketFiles) { for (bucketFile <- bucketFiles) { - withSQLConf("spark.sql.parquet.enableUnsafeRowRecordReader" -> "false") { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath) - checkAnswer(df.sort("k"), df.collect()) - val rows = df.select("i", "j").queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } + val df = sqlContext.read.parquet(bucketFile.getAbsolutePath) + checkAnswer(df.sort("k"), df.collect()) + val rows = df.select("i", "j").queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) } } } From 3df61dcf76f7991a3fc47254a54e135ad2c044dd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 13:55:55 +0800 Subject: [PATCH 10/15] fix tests --- 
.../sql/sources/BucketedWriteSuite.scala | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index dc9aea0b9ccf..65992464133a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -63,16 +63,14 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data") { val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketedTable") { + withTable("bucketed_table") { df.write .format("parquet") .partitionBy("i") .bucketBy(8, "j", "k") - .saveAsTable("bucketedTable") + .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketedTable") - logWarning(hiveContext.warehousePath.getAbsolutePath) - logWarning(hiveContext.warehousePath.listFiles().map(_.getAbsolutePath).mkString("\n")) + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { val allBucketFiles = new File(tableDir, s"i=$i").listFiles().filter(!_.isHidden) val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) @@ -96,15 +94,15 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data with sortBy") { val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketedTable") { + withTable("bucketed_table") { df.write .format("parquet") .partitionBy("i") .bucketBy(8, "j") .sortBy("k") - .saveAsTable("bucketedTable") + .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { val allBucketFiles = new File(tableDir, s"i=$i").listFiles() .filter(_.getName.startsWith("part")) @@ -130,13 +128,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data without partitionBy") { val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketedTable") { + withTable("bucketed_table") { df.write .format("parquet") .bucketBy(8, "i", "j") - .saveAsTable("bucketedTable") + .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) assert(groupedBucketFiles.size <= 8) @@ -158,14 +156,14 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle test("write bucketed data without partitionBy with sortBy") { val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketedTable") { + withTable("bucketed_table") { df.write .format("parquet") .bucketBy(8, "i", "j") .sortBy("k") - .saveAsTable("bucketedTable") + .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketedTable") + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) 
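The getBucketId helper used above recovers the bucket id from the digits the writer appends to each part-file name. A self-contained sketch of that parsing, with a made-up file name for illustration (the exact naming layout is whatever the writer emits):

    val parquetFileName = """.*-(\d+)\..*\.parquet""".r

    def bucketIdOf(fileName: String): Option[Int] = fileName match {
      case parquetFileName(id) => Some(id.toInt)
      case _ => None
    }

    assert(bucketIdOf("part-r-00000-xyz-00005.gz.parquet") == Some(5))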
assert(groupedBucketFiles.size <= 8) From d5f390d1d54bceafc7ed8bade76cd0831d095cac Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 16:18:15 +0800 Subject: [PATCH 11/15] refine --- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../sql/execution/datasources/bucket.scala | 4 +- .../datasources/json/JSONRelation.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 2 +- .../datasources/text/DefaultSource.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 2 +- .../sql/sources/SimpleTextRelation.scala | 39 +++++++++++-------- 7 files changed, 31 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8b05a10d0a87..528e496bd56f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -117,7 +117,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning schema. + * laid out on the file system similar to Hive's partitioning scheme. * * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. * @@ -131,7 +131,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Buckets the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's bucketing schema. + * laid out on the file system similar to Hive's bucketing scheme. * * This is applicable for Parquet, JSON, text, ORC and avro. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 8f53272a2ce6..70865c62e201 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} @@ -72,6 +72,8 @@ private[sql] abstract class BucketedHadoopFsRelation( def this(parameters: Map[String, String]) = this(None, parameters) def bucketSpec: Option[BucketSpec] + + def prepareJobForWrite(job: Job): BucketedOutputWriterFactory } private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 91e93245f0dc..da0ea30b3a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -178,7 +178,7 @@ private[sql] class JSONRelation( partitionColumns) } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { new BucketedOutputWriterFactory { override def 
newInstance( path: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 23685fee0099..267465ba36e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -224,7 +224,7 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index f2dcec90bd06..6f53c817907f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -116,7 +116,7 @@ private[sql] class TextRelation( } /** Write path. */ - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { new BucketedOutputWriterFactory { override def newInstance( path: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 131a1b98f997..5e6089a04d53 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -210,7 +210,7 @@ private[sql] class OrcRelation( OrcTableScan(output, this, filters, inputPaths).execute() } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index fde67ed2cb90..f588ad868761 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -181,15 +181,17 @@ class SimpleTextRelation( } } - override def prepareJobForWrite(job: Job): OutputWriterFactory = new BucketedOutputWriterFactory { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) - override def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, bucketId, context) + new BucketedOutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, bucketId, context) + } } } @@ -240,16 +242,19 @@ class CommitFailureTestRelation( userDefinedPartitionColumns, bucketSpec, 
parameters)(sqlContext) { - override def prepareJobForWrite(job: Job): OutputWriterFactory = new BucketedOutputWriterFactory { - override def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, bucketId, context) { - override def close(): Unit = { - super.close() - sys.error("Intentional task commitment failure for testing purpose.") + + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + new BucketedOutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, bucketId, context) { + override def close(): Unit = { + super.close() + sys.error("Intentional task commitment failure for testing purpose.") + } } } } From 3ff968b29d3852c92952454254ae6e1f7ba6599d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 20:00:58 +0800 Subject: [PATCH 12/15] address comments --- .../apache/spark/sql/DataFrameWriter.scala | 24 +- .../sql/execution/datasources/DDLParser.scala | 2 +- .../InsertIntoHadoopFsRelation.scala | 31 +- .../datasources/ResolvedDataSource.scala | 8 +- .../datasources/WriterContainer.scala | 334 ++++++++++-------- .../sql/execution/datasources/bucket.scala | 30 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../sql/execution/datasources/rules.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +- .../spark/sql/hive/execution/commands.scala | 2 +- 11 files changed, 261 insertions(+), 194 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 528e496bd56f..7e3d7907377d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -140,7 +140,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { @scala.annotation.varargs def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { this.numBuckets = Option(numBuckets) - this.bucketingColumns = Option(colName +: colNames) + this.bucketColumnNames = Option(colName +: colNames) this } @@ -153,7 +153,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ @scala.annotation.varargs def sortBy(colName: String, colNames: String*): DataFrameWriter = { - this.sortingColumns = Option(colName +: colNames) + this.sortColumnNames = Option(colName +: colNames) this } @@ -224,28 +224,32 @@ final class DataFrameWriter private[sql](df: DataFrame) { cols.map(normalize(_, "Partition")) } - private def normalizedBucketCols: Option[Seq[String]] = bucketingColumns.map { cols => + private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => cols.map(normalize(_, "Bucketing")) } - private def normalizedSortCols: Option[Seq[String]] = sortingColumns.map { cols => + private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => cols.map(normalize(_, "Sorting")) } private def getBucketSpec: Option[BucketSpec] = { - if (sortingColumns.isDefined) { + if (sortColumnNames.isDefined) { require(numBuckets.isDefined, "sortBy must be used together with bucketBy") } for { n <- numBuckets - cols <- normalizedBucketCols } yield { require(n > 0, "Bucket number must be greater than 0.") - BucketSpec(n, cols, 
normalizedSortCols) + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) } } + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. + */ private def normalize(columnName: String, columnType: String): String = { val validColumnNames = df.logicalPlan.output.map(_.name) validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName)) @@ -254,7 +258,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def assertNotBucketed(): Unit = { - if (numBuckets.isDefined || sortingColumns.isDefined) { + if (numBuckets.isDefined || sortColumnNames.isDefined) { throw new IllegalArgumentException( "Currently we don't support writing bucketed data to this data source.") } @@ -435,9 +439,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var partitioningColumns: Option[Seq[String]] = None - private var bucketingColumns: Option[Seq[String]] = None + private var bucketColumnNames: Option[Seq[String]] = None private var numBuckets: Option[Int] = None - private var sortingColumns: Option[Seq[String]] = None + private var sortColumnNames: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 32fef585a1b5..6b3671d51071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -109,7 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], - None, + bucketSpec = None, mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index aeeefecc7973..80da2bf720cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -136,16 +136,27 @@ private[sql] case class InsertIntoHadoopFsRelation( val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - new DynamicPartitionWriterContainer( - relation, - job, - partitionOutput, - bucketSpec, - dataOutput, - output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), - isAppend) + if (bucketSpec.isEmpty) { + new DynamicPartitionWriterContainer( + relation, + job, + partitionOutput, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + isAppend) + } else { + new BucketedPartitionWriterContainer( + relation.asInstanceOf[BucketedHadoopFsRelation], + job, + partitionOutput, + bucketSpec.get, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + isAppend) + } } // This call shouldn't be put into the `try` block below because it only initializes and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 4aa451f48b13..87b3e0b44695 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -244,15 +244,15 @@ object ResolvedDataSource extends Logging { case provider: BucketedHadoopFsRelationProvider => provider.createRelation( sqlContext, Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + Option(dataSchema.asNullable), + Option(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), bucketSpec, caseInsensitiveOptions) case provider: HadoopFsRelationProvider => provider.createRelation( sqlContext, Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + Option(dataSchema.asNullable), + Option(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), caseInsensitiveOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 4840d8900d87..59c9609587e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -121,14 +121,15 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String, bucketId: Option[Int]): OutputWriter = { + protected def newOutputWriter(path: String): OutputWriter = { + improveErrorMessage { + outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + } + } + + protected def improveErrorMessage[T](f: => T): T = { try { - outputWriterFactory match { - case factory: BucketedOutputWriterFactory => - factory.newInstance(path, bucketId, dataSchema, taskAttemptContext) - case factory: OutputWriterFactory => - factory.newInstance(path, dataSchema, taskAttemptContext) - } + f } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -252,7 +253,7 @@ private[sql] class DefaultWriterContainer( executorSideSetup(taskContext) val configuration = taskAttemptContext.getConfiguration configuration.set("spark.sql.sources.output.path", outputPath) - val writer = newOutputWriter(getWorkPath, None) + val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) var writerClosed = false @@ -310,7 +311,6 @@ private[sql] class DynamicPartitionWriterContainer( relation: HadoopFsRelation, job: Job, partitionColumns: Seq[Attribute], - bucketSpec: Option[BucketSpec], dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], defaultPartitionName: String, @@ -318,35 +318,14 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - private def bucketColumns = bucketSpec.get.resolvedBucketingColumns(inputSchema) - private def sortColumns = bucketSpec.get.resolvedSortingColumns(inputSchema) - private def numBuckets = bucketSpec.get.numBuckets - private def bucketIdExpr = Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) var outputWritersCleared = false - val getKey = if (bucketSpec.isEmpty) { - 
UnsafeProjection.create(partitionColumns, inputSchema) - } else { // If it's bucketed, we should also consider bucket id as part of the key. - UnsafeProjection.create(partitionColumns :+ bucketIdExpr, inputSchema) - } - - val keySchema = if (bucketSpec.isEmpty) { - StructType.fromAttributes(partitionColumns) - } else { // If it's bucketed, we should also consider bucket id as part of the key. - StructType.fromAttributes(partitionColumns).add("bucketId", IntegerType, nullable = false) - } - - def getBucketId(key: InternalRow): Option[Int] = if (bucketSpec.isDefined) { - Some(key.getInt(partitionColumns.length)) - } else { - None - } - + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -369,125 +348,66 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - val mustSort = bucketSpec.isDefined && bucketSpec.get.sortingColumns.isDefined - // TODO: remove duplicated code. - if (mustSort) { - val getSortingKey = - UnsafeProjection.create((partitionColumns :+ bucketIdExpr) ++ sortColumns, inputSchema) - - val sortingKeySchema = { - val fields = StructType.fromAttributes(partitionColumns) - .add("bucketId", IntegerType, nullable = false) ++ - StructType.fromAttributes(sortColumns) - StructType(fields) + // This will be filled in if we have to fall back on sorting. + var sorter: UnsafeKVExternalSorter = null + while (iterator.hasNext && sorter == null) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + var currentWriter = outputWriters.get(currentKey) + + if (currentWriter == null) { + if (outputWriters.size < maxOpenFiles) { + currentWriter = newOutputWriter(currentKey) + outputWriters.put(currentKey.copy(), currentWriter) + currentWriter.writeInternal(getOutputRow(inputRow)) + } else { + logInfo(s"Maximum partitions reached, falling back on sorting.") + sorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionColumns), + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } + } else { + currentWriter.writeInternal(getOutputRow(inputRow)) } + } - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - + // If the sorter is not null that means that we reached the maxFiles above and need to finish + // using external sort. + if (sorter != null) { while (iterator.hasNext) { val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) } logInfo(s"Sorting complete. 
Writing out partition files one at a time.") - def sameBucket(row1: InternalRow, row2: InternalRow): Boolean = { - if (row1.getInt(partitionColumns.length) != row2.getInt(partitionColumns.length)) { - false - } else { - var i = partitionColumns.length - 1 - while (i >= 0) { - val dt = partitionColumns(i).dataType - if (row1.get(i, dt) != row2.get(i, dt)) return false - i -= 1 - } - true - } - } val sortedIterator = sorter.sortedIterator() var currentKey: InternalRow = null var currentWriter: OutputWriter = null try { while (sortedIterator.next()) { - if (currentKey == null || !sameBucket(currentKey, sortedIterator.getKey)) { + if (currentKey != sortedIterator.getKey) { if (currentWriter != null) { currentWriter.close() } currentKey = sortedIterator.getKey.copy() logDebug(s"Writing partition: $currentKey") - currentWriter = newOutputWriter(currentKey) + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey) + } } + currentWriter.writeInternal(sortedIterator.getValue) } } finally { if (currentWriter != null) { currentWriter.close() } } - } else { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null - while (iterator.hasNext && sorter == null) { - val inputRow = iterator.next() - val currentKey = getKey(inputRow) - var currentWriter = outputWriters.get(currentKey) - - if (currentWriter == null) { - if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) - outputWriters.put(currentKey.copy(), currentWriter) - currentWriter.writeInternal(getOutputRow(inputRow)) - } else { - logInfo(s"Maximum partitions reached, falling back on sorting.") - sorter = new UnsafeKVExternalSorter( - keySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - sorter.insertKV(currentKey, getOutputRow(inputRow)) - } - } else { - currentWriter.writeInternal(getOutputRow(inputRow)) - } - } - - // If the sorter is not null that means that we reached the maxFiles above and need to - // finish using external sort. - if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } - } } commitTask() @@ -500,18 +420,12 @@ private[sql] class DynamicPartitionWriterContainer( /** Open and returns a new OutputWriter given a partition key. 
*/ def newOutputWriter(key: InternalRow): OutputWriter = { + val partitionPath = getPartitionString(key).getString(0) + val path = new Path(getWorkPath, partitionPath) val configuration = taskAttemptContext.getConfiguration - val path = if (partitionColumns.nonEmpty) { - val partitionPath = getPartitionString(key).getString(0) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - new Path(getWorkPath, partitionPath).toString - } else { - configuration.set("spark.sql.sources.output.path", outputPath) - getWorkPath - } - val bucketId = getBucketId(key) - val newWriter = super.newOutputWriter(path, bucketId) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + val newWriter = super.newOutputWriter(path.toString) newWriter.initConverter(dataSchema) newWriter } @@ -543,3 +457,147 @@ private[sql] class DynamicPartitionWriterContainer( } } } + +/** + * A writer that dynamically opens files based on the given partition columns. Internally this is + * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the + * writer externally sorts the remaining rows and then writes out them out one file at a time. + */ +private[sql] class BucketedPartitionWriterContainer( + relation: BucketedHadoopFsRelation, + job: Job, + partitionColumns: Seq[Attribute], + bucketSpec: BucketSpec, + dataColumns: Seq[Attribute], + inputSchema: Seq[Attribute], + defaultPartitionName: String, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + executorSideSetup(taskContext) + + val bucketColumns = bucketSpec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + val numBuckets = bucketSpec.numBuckets + val sortColumns = bucketSpec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + val bucketIdExpr = Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) + + val getSortingKey = + UnsafeProjection.create((partitionColumns :+ bucketIdExpr) ++ sortColumns, inputSchema) + + val sortingKeySchema = { + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ + StructType.fromAttributes(sortColumns) + StructType(fields) + } + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF( + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartitionName), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + } + + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // If anything below fails, we should abort the task. + try { + // TODO: remove duplicated code. + // TODO: if sorting columns are empty, we can keep all writers in a hash map and avoid sorting + // here. 
+ val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + def sameBucket(row1: InternalRow, row2: InternalRow): Boolean = { + if (row1.getInt(partitionColumns.length) != row2.getInt(partitionColumns.length)) { + false + } else { + var i = partitionColumns.length - 1 + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (row1.get(i, dt) != row2.get(i, dt)) return false + i -= 1 + } + true + } + } + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (currentKey == null || !sameBucket(currentKey, sortedIterator.getKey)) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) + } + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = key.getInt(partitionColumns.length) + val newWriter = improveErrorMessage { + outputWriterFactory.asInstanceOf[BucketedOutputWriterFactory].newInstance( + path, Some(bucketId), dataSchema, taskAttemptContext) + } + newWriter.initConverter(dataSchema) + newWriter + } + + def commitTask(): Unit = { + try { + super.commitTask() + } catch { + case cause: Throwable => + throw new RuntimeException("Failed to commit task", cause) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 70865c62e201..dc6d4488fe71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -23,27 +23,19 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} import org.apache.spark.sql.types.StructType +/** + * A container for bucketing information. + * Bucketing is a technology for decomposing data sets into more manageable parts, and the number + * of buckets is fixed so it does not fluctuate with data. + * + * @param numBuckets number of buckets. + * @param bucketColumnNames the names of the columns that used to generate the bucket id. 
+ * @param sortColumnNames the names of the columns that used to sort data in each bucket. + */ private[sql] case class BucketSpec( numBuckets: Int, - bucketingColumns: Seq[String], - sortingColumns: Option[Seq[String]]) { - - def resolvedBucketingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { - bucketingColumns.map { col => - inputSchema.find(_.name == col).get - } - } - - def resolvedSortingColumns(inputSchema: Seq[Attribute]): Seq[Attribute] = { - if (sortingColumns.isDefined) { - sortingColumns.get.map { col => - inputSchema.find(_.name == col).get - } - } else { - Nil - } - } -} + bucketColumnNames: Seq[String], + sortColumnNames: Seq[String]) private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider { final override def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 21fea9e6eeb9..19581c9794ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -114,7 +114,7 @@ case class CreateTempTableUsingAsSelect( sqlContext, provider, partitionColumns, - None, + bucketSpec = None, mode, options, df) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 54e46e23ea65..d356838448dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -194,12 +194,12 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => PartitioningUtils.validatePartitionColumnDataTypes( c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) - c.bucketSpec.foreach(_.sortingColumns.foreach(_.foreach { sortCol => + c.bucketSpec.foreach(_.sortColumnNames.foreach { sortCol => val dataType = c.child.schema.find(_.name == sortCol).get.dataType if (!RowOrdering.isOrderable(dataType)) { failAnalysis(s"Cannot use ${dataType.simpleString} for sorting column.") } - })) + }) case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b22569fce288..19ca5d8691d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{BucketSpec, PartitioningUtils, PartitionSpec, Partition} +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f3df871e7540..aa3627808441 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -242,17 +242,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { - val BucketSpec(numBuckets, bucketColumns, sortColumns) = bucketSpec.get + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) - tableProperties.put("spark.sql.sources.schema.numBucketCols", bucketColumns.length.toString) - bucketColumns.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put("spark.sql.sources.schema.numBucketCols", + bucketColumnNames.length.toString) + bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) } - if (sortColumns.isDefined) { - tableProperties.put("spark.sql.sources.schema.numSortCols", sortColumns.get.length.toString) - sortColumns.get.zipWithIndex.foreach { case (sortCol, index) => + if (sortColumnNames.nonEmpty) { + tableProperties.put("spark.sql.sources.schema.numSortCols", + sortColumnNames.length.toString) + sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) } } @@ -614,7 +616,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], - None, + bucketSpec = None, mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index bb0b81925a2d..612f01cda88b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -151,7 +151,7 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], - None, + bucketSpec = None, provider, optionsWithPath, isExternal) From 74bd52461f381f67da737b8c9db595b09c77ad8d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 21:05:35 +0800 Subject: [PATCH 13/15] improve test --- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../datasources/text/DefaultSource.scala | 27 +-- .../sql/sources/BucketedWriteSuite.scala | 183 ++++++++---------- .../sql/sources/SimpleTextRelation.scala | 77 +++----- 4 files changed, 119 insertions(+), 172 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7e3d7907377d..6496a8d757b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -133,7 +133,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * Buckets the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's bucketing scheme. * - * This is applicable for Parquet, JSON, text, ORC and avro. + * This is applicable for Parquet, JSON and ORC. * * @since 2.0 */ @@ -147,7 +147,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Sorts the bucketed output by the given columns. * - * This is applicable for Parquet, JSON, text, ORC and avro. + * This is applicable for Parquet, JSON and ORC. 
* * @since 2.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 6f53c817907f..fe69c72d28cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -39,17 +39,16 @@ import org.apache.spark.util.SerializableConfiguration /** * A data source for reading text files. */ -class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation = { + parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation(None, dataSchema, partitionColumns, bucketSpec, paths)(sqlContext) + new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext) } override def shortName(): String = "text" @@ -71,11 +70,10 @@ private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], val textSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], - val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) - extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) { + extends HadoopFsRelation(maybePartitionSpec, parameters) { /** Data schema is always a single column, named "value" if original Data source has no schema. */ override def dataSchema: StructType = @@ -116,14 +114,13 @@ private[sql] class TextRelation( } /** Write path. 
*/ - override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { - new BucketedOutputWriterFactory { + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { override def newInstance( path: String, - bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, bucketId, context) + new TextOutputWriter(path, dataSchema, context) } } } @@ -139,10 +136,7 @@ private[sql] class TextRelation( } } -class TextOutputWriter( - path: String, - bucketId: Option[Int], - context: TaskAttemptContext) +class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private[this] val buffer = new Text() @@ -154,8 +148,7 @@ class TextOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } }.getRecordWriter(context) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 65992464133a..d156c43966b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.io.File +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -54,132 +55,108 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) } - private val parquetFileName = """.*-(\d+)\..*\.parquet""".r + private val testFileName = """.*-(\d+)$""".r + private val otherFileName = """.*-(\d+)\..*""".r private def getBucketId(fileName: String): Int = { fileName match { - case parquetFileName(bucketId) => bucketId.toInt + case testFileName(bucketId) => bucketId.toInt + case otherFileName(bucketId) => bucketId.toInt } } + private def testBucketing( + dataDir: File, + source: String, + bucketCols: Seq[String], + sortCols: Seq[String] = Nil): Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => + f.getName.startsWith(".") || f.getName.startsWith("_") + ) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath) + .select((bucketCols ++ sortCols).map(col): _*) + + if (sortCols.nonEmpty) { + checkAnswer(df.sort(sortCols.map(col): _*), df.collect()) + } + + val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) + } + } + } + } + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + test("write bucketed data") { - val df = (0 until 50).map(i => (i % 5, i % 13, 
i.toString)).toDF("i", "j", "k") - withTable("bucketed_table") { - df.write - .format("parquet") - .partitionBy("i") - .bucketBy(8, "j", "k") - .saveAsTable("bucketed_table") - - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - for (i <- 0 until 5) { - val allBucketFiles = new File(tableDir, s"i=$i").listFiles().filter(!_.isHidden) - val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) - assert(groupedBucketFiles.size <= 8) - - for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFile <- bucketFiles) { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") - val rows = df.queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } - } + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k")) } } } } test("write bucketed data with sortBy") { - val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketed_table") { - df.write - .format("parquet") - .partitionBy("i") - .bucketBy(8, "j") - .sortBy("k") - .saveAsTable("bucketed_table") - - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - for (i <- 0 until 5) { - val allBucketFiles = new File(tableDir, s"i=$i").listFiles() - .filter(_.getName.startsWith("part")) - val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) - assert(groupedBucketFiles.size <= 8) - - for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFile <- bucketFiles) { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("j", "k") - checkAnswer(df.sort("k"), df.collect()) - val rows = df.select("j").queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } - } + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k")) } } } } test("write bucketed data without partitionBy") { - val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketed_table") { - df.write - .format("parquet") - .bucketBy(8, "i", "j") - .saveAsTable("bucketed_table") - - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) - val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) - assert(groupedBucketFiles.size <= 8) - - for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFile <- bucketFiles) { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath).select("i", "j") - val rows = df.queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == 
bucketId) - } - } + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j")) } } } test("write bucketed data without partitionBy with sortBy") { - val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - withTable("bucketed_table") { - df.write - .format("parquet") - .bucketBy(8, "i", "j") - .sortBy("k") - .saveAsTable("bucketed_table") - - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - val allBucketFiles = tableDir.listFiles().filter(_.getName.startsWith("part")) - val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) - assert(groupedBucketFiles.size <= 8) - - for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFile <- bucketFiles) { - val df = sqlContext.read.parquet(bucketFile.getAbsolutePath) - checkAnswer(df.sort("k"), df.collect()) - val rows = df.select("i", "j").queryExecution.toRdd.map(_.copy()).collect() - - for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 - assert(actualBucketId == bucketId) - } - } + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j"), Seq("k")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index f588ad868761..e10d21d5e368 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -28,28 +28,24 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} -import org.apache.spark.sql.execution.datasources.{BucketedOutputWriterFactory, BucketedHadoopFsRelation, BucketedHadoopFsRelationProvider, BucketSpec} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. 
*/ -class SimpleTextSource extends BucketedHadoopFsRelationProvider { +class SimpleTextSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation = { - new SimpleTextRelation(paths, schema, partitionColumns, bucketSpec, parameters)(sqlContext) + parameters: Map[String, String]): HadoopFsRelation = { + new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) } } -class AppendingTextOutputFormat( - outputFile: Path, - bucketId: Option[Int]) extends TextOutputFormat[NullWritable, Text] { +class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] { val numberFormat = NumberFormat.getInstance() numberFormat.setMinimumIntegerDigits(5) @@ -58,20 +54,16 @@ class AppendingTextOutputFormat( override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId$bucketString") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } } -class SimpleTextOutputWriter( - path: String, - bucketId: Option[Int], - context: TaskAttemptContext) extends OutputWriter { +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(path), bucketId).getRecordWriter(context) + new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map { v => @@ -94,10 +86,9 @@ class SimpleTextRelation( override val paths: Array[String], val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], - val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends BucketedHadoopFsRelation(parameters) { + extends HadoopFsRelation(parameters) { import sqlContext.sparkContext @@ -181,17 +172,14 @@ class SimpleTextRelation( } } - override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) - new BucketedOutputWriterFactory { - override def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, bucketId, context) - } + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) } } @@ -216,16 +204,14 @@ object SimpleTextRelation { /** * A simple example [[HadoopFsRelationProvider]]. 
*/ -class CommitFailureTestSource extends BucketedHadoopFsRelationProvider { +class CommitFailureTestSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation = { - new CommitFailureTestRelation(paths, schema, partitionColumns, bucketSpec, parameters)( - sqlContext) + parameters: Map[String, String]): HadoopFsRelation = { + new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) } } @@ -233,28 +219,19 @@ class CommitFailureTestRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient sqlContext: SQLContext) extends SimpleTextRelation( - paths, - maybeDataSchema, - userDefinedPartitionColumns, - bucketSpec, - parameters)(sqlContext) { - - override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { - new BucketedOutputWriterFactory { - override def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, bucketId, context) { - override def close(): Unit = { - super.close() - sys.error("Intentional task commitment failure for testing purpose.") - } + paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) { + override def close(): Unit = { + super.close() + sys.error("Intentional task commitment failure for testing purpose.") } } } From d3200cf8bdffaf0025b25c30d7c3d9fef8a6f9a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Jan 2016 19:59:25 +0800 Subject: [PATCH 14/15] simplification --- .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../InsertIntoHadoopFsRelation.scala | 37 +- .../datasources/ResolvedDataSource.scala | 22 +- .../datasources/WriterContainer.scala | 365 ++++++++---------- .../sql/execution/datasources/bucket.scala | 34 +- .../datasources/json/JSONRelation.scala | 6 +- .../datasources/parquet/ParquetRelation.scala | 6 +- .../apache/spark/sql/sources/interfaces.scala | 27 +- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- 9 files changed, 212 insertions(+), 299 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 112e01fdb89b..00f9817b5397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -129,8 +129,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Buckets the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's bucketing scheme. + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. * * This is applicable for Parquet, JSON and ORC. * @@ -144,7 +144,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Sorts the bucketed output by the given columns. 
+ * Sorts the output in each bucket by the given columns. * * This is applicable for Parquet, JSON and ORC. * @@ -239,7 +239,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { for { n <- numBuckets } yield { - require(n > 0, "Bucket number must be greater than 0.") + require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 403f4c0919db..7a8691e7cb9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -125,39 +125,22 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val bucketSpec = relation match { - case relation: BucketedHadoopFsRelation => relation.bucketSpec - case _ => None - } - - val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - if (bucketSpec.isEmpty) { - new DynamicPartitionWriterContainer( - relation, - job, - partitionOutput, - dataOutput, - output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), - isAppend) - } else { - new BucketedPartitionWriterContainer( - relation.asInstanceOf[BucketedHadoopFsRelation], - job, - partitionOutput, - bucketSpec.get, - dataOutput, - output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - isAppend) - } + new DynamicPartitionWriterContainer( + relation, + job, + partitionOutput, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + isAppend) } // This call shouldn't be put into the `try` block below because it only initializes and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 86e1e3c2f624..ece9b8a9a917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -240,21 +240,13 @@ object ResolvedDataSource extends Logging { val equality = columnNameEquality(caseSensitive) val dataSchema = StructType( data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - val r = dataSource match { - case provider: BucketedHadoopFsRelationProvider => provider.createRelation( - sqlContext, - Array(outputPath.toString), - Option(dataSchema.asNullable), - Option(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), - bucketSpec, - caseInsensitiveOptions) - case provider: HadoopFsRelationProvider => provider.createRelation( - sqlContext, - Array(outputPath.toString), - Option(dataSchema.asNullable), - Option(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), - caseInsensitiveOptions) - } + val r = 
dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + bucketSpec, + caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9c38d64db9e6..5e580c0be07a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType, StringType} import org.apache.spark.util.SerializableConfiguration @@ -121,15 +121,9 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String): OutputWriter = { - improveErrorMessage { - outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) - } - } - - protected def improveErrorMessage[T](f: => T): T = { + protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { try { - f + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -318,19 +312,23 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] - executorSideSetup(taskContext) + private val bucketSpec = relation.bucketSpec - var outputWritersCleared = false + private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + } + + private def bucketIdExpression: Option[Expression] = for { + BucketSpec(numBuckets, _, _) <- bucketSpec + } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + // Expressions that given a partition key build a string like: col1=val/col2=val/... 
+ private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( PartitioningUtils.escapePathName _, @@ -341,6 +339,117 @@ private[sql] class DynamicPartitionWriterContainer( val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } + } + + private def getBucketIdFromKey(key: InternalRow): Option[Int] = { + if (bucketSpec.isDefined) { + Some(key.getInt(partitionColumns.length)) + } else { + None + } + } + + private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { + val bucketIdIndex = partitionColumns.length + if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { + false + } else { + var i = partitionColumns.length - 1 + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (key1.get(i, dt) != key2.get(i, dt)) return false + i -= 1 + } + true + } + } + + private def sortBasedWrite( + sorter: UnsafeKVExternalSorter, + iterator: Iterator[InternalRow], + getSortingKey: UnsafeProjection, + getOutputRow: UnsafeProjection, + getPartitionString: UnsafeProjection, + outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { + (key1, key2) => key1 != key2 + } else { + (key1, key2) => key1 == null || !sameBucket(key1, key2) + } + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (needNewWriter(currentKey, sortedIterator.getKey)) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + /** Open and returns a new OutputWriter given a partition key and optional bucket id. */ + private def newOutputWriter( + key: InternalRow, + getPartitionString: UnsafeProjection): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = getBucketIdFromKey(key) + val newWriter = super.newOutputWriter(path, bucketId) + newWriter.initConverter(dataSchema) + newWriter + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + var outputWritersCleared = false + + // We should first sort by partition columns, then bucket id, and finally sorting columns. 
+ val getSortingKey = + UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) + + val sortingKeySchema = if (bucketSpec.isEmpty) { + StructType.fromAttributes(partitionColumns) + } else { // If it's bucketed, we should also consider bucket id as part of the key. + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) + StructType(fields) + } + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) // Returns the partition path given a partition key. val getPartitionString = @@ -348,22 +457,34 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null + // If there is no sorting columns, we set sorter to null and try the hash-based writing first, + // and fill the sorter if there are too many writers and we need to fall back on sorting. + // If there are sorting columns, then we have to sort the data anyway, and no need to try the + // hash-based writing first. + var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + } else { + null + } while (iterator.hasNext && sorter == null) { val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) + // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key. + val currentKey = getSortingKey(inputRow) var currentWriter = outputWriters.get(currentKey) if (currentWriter == null) { if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) + currentWriter = newOutputWriter(currentKey, getPartitionString) outputWriters.put(currentKey.copy(), currentWriter) currentWriter.writeInternal(getOutputRow(inputRow)) } else { logInfo(s"Maximum partitions reached, falling back on sorting.") sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), + sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) @@ -375,39 +496,15 @@ private[sql] class DynamicPartitionWriterContainer( } // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. + // using external sort, or there are sorting columns and we need to sort the whole data set. if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. 
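// The control flow above, namely try hash-based writing with one open writer per key and
// fall back to an external sort once maxOpenFiles is exceeded (or sort from the start when
// sort columns are present), can be modelled in plain Scala as below. Writer, newWriter and
// maxOpenFiles are illustrative stand-ins, and the fallback uses a simple in-memory sortBy
// rather than UnsafeKVExternalSorter.
object FallbackWriteSketch {
  trait Writer { def write(row: Seq[Any]): Unit; def close(): Unit }

  def writeRows(
      rows: Iterator[(Seq[Any], Seq[Any])], // (key, data columns) pairs
      newWriter: Seq[Any] => Writer,
      maxOpenFiles: Int,
      mustSort: Boolean): Unit = {
    val open = scala.collection.mutable.Map.empty[Seq[Any], Writer]
    var spilled = List.empty[(Seq[Any], Seq[Any])]
    var sorting = mustSort
    while (rows.hasNext) {
      val (key, data) = rows.next()
      if (!sorting && (open.contains(key) || open.size < maxOpenFiles)) {
        open.getOrElseUpdate(key, newWriter(key)).write(data) // hash-based path
      } else {
        sorting = true // too many open files (or sort columns present): fall back on sorting
        spilled ::= (key, data)
      }
    }
    // Sort-based path: group the spilled rows by key and write one file at a time.
    var current: Option[(Seq[Any], Writer)] = None
    for ((key, data) <- spilled.sortBy(_._1.mkString("\u0000"))) {
      if (!current.exists(_._1 == key)) {
        current.foreach(_._2.close())
        current = Some(key -> open.remove(key).getOrElse(newWriter(key)))
      }
      current.foreach(_._2.write(data))
    }
    current.foreach(_._2.close())
    open.values.foreach(_.close()) // close writers left over from the hash-based path
  }
}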
- currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } + sortBasedWrite( + sorter, + iterator, + getSortingKey, + getOutputRow, + getPartitionString, + outputWriters) } commitTask() @@ -418,18 +515,6 @@ private[sql] class DynamicPartitionWriterContainer( throw new SparkException("Task failed while writing rows.", cause) } - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) - val path = new Path(getWorkPath, partitionPath) - val configuration = taskAttemptContext.getConfiguration - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - def clearOutputWriters(): Unit = { if (!outputWritersCleared) { outputWriters.asScala.values.foreach(_.close()) @@ -457,147 +542,3 @@ private[sql] class DynamicPartitionWriterContainer( } } } - -/** - * A writer that dynamically opens files based on the given partition columns. Internally this is - * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the - * writer externally sorts the remaining rows and then writes out them out one file at a time. - */ -private[sql] class BucketedPartitionWriterContainer( - relation: BucketedHadoopFsRelation, - job: Job, - partitionColumns: Seq[Attribute], - bucketSpec: BucketSpec, - dataColumns: Seq[Attribute], - inputSchema: Seq[Attribute], - defaultPartitionName: String, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - executorSideSetup(taskContext) - - val bucketColumns = bucketSpec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) - val numBuckets = bucketSpec.numBuckets - val sortColumns = bucketSpec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) - val bucketIdExpr = Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) - - val getSortingKey = - UnsafeProjection.create((partitionColumns :+ bucketIdExpr) ++ sortColumns, inputSchema) - - val sortingKeySchema = { - val fields = StructType.fromAttributes(partitionColumns) - .add("bucketId", IntegerType, nullable = false) ++ - StructType.fromAttributes(sortColumns) - StructType(fields) - } - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) - - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = - ScalaUDF( - PartitioningUtils.escapePathName _, - StringType, - Seq(Cast(c, StringType)), - Seq(StringType)) - val str = If(IsNull(c), Literal(defaultPartitionName), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName - } - - // Returns the partition path given a partition key. - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - - // If anything below fails, we should abort the task. - try { - // TODO: remove duplicated code. 
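// With the shared sort-based path above, the duplicated BucketedPartitionWriterContainer
// being deleted below is no longer needed; the remaining behavioural switch is the
// needNewWriter predicate. Without sort columns any change of key opens a new file; with
// sort columns only a change of the partition-plus-bucket prefix does. A standalone sketch
// of that predicate (keys modelled as Seq[Any]; names illustrative, not Spark's).
object NeedNewWriterSketch {
  def needNewWriter(
      hasSortColumns: Boolean,
      numPartitionCols: Int,
      bucketed: Boolean): (Option[Seq[Any]], Seq[Any]) => Boolean = {
    if (!hasSortColumns) {
      (prev, next) => !prev.contains(next) // any change in the full key
    } else {
      val prefixLen = numPartitionCols + (if (bucketed) 1 else 0)
      (prev, next) => !prev.exists(_.take(prefixLen) == next.take(prefixLen))
    }
  }
}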
- // TODO: if sorting columns are empty, we can keep all writers in a hash map and avoid sorting - // here. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - def sameBucket(row1: InternalRow, row2: InternalRow): Boolean = { - if (row1.getInt(partitionColumns.length) != row2.getInt(partitionColumns.length)) { - false - } else { - var i = partitionColumns.length - 1 - while (i >= 0) { - val dt = partitionColumns(i).dataType - if (row1.get(i, dt) != row2.get(i, dt)) return false - i -= 1 - } - true - } - } - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey == null || !sameBucket(currentKey, sortedIterator.getKey)) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - currentWriter = newOutputWriter(currentKey) - } - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } - - commitTask() - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val configuration = taskAttemptContext.getConfiguration - val path = if (partitionColumns.nonEmpty) { - val partitionPath = getPartitionString(key).getString(0) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - new Path(getWorkPath, partitionPath).toString - } else { - configuration.set("spark.sql.sources.output.path", outputPath) - getWorkPath - } - val bucketId = key.getInt(partitionColumns.length) - val newWriter = improveErrorMessage { - outputWriterFactory.asInstanceOf[BucketedOutputWriterFactory].newInstance( - path, Some(bucketId), dataSchema, taskAttemptContext) - } - newWriter.initConverter(dataSchema) - newWriter - } - - def commitTask(): Unit = { - try { - super.commitTask() - } catch { - case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index dc6d4488fe71..86e5c4a985ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} import org.apache.spark.sql.types.StructType @@ -44,28 +43,7 @@ private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProv 
dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = - createRelation(sqlContext, paths, dataSchema, partitionColumns, None, parameters) - - def createRelation( - sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation -} - -private[sql] abstract class BucketedHadoopFsRelation( - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String]) - extends HadoopFsRelation(maybePartitionSpec, parameters) { - def this() = this(None, Map.empty[String, String]) - - def this(parameters: Map[String, String]) = this(None, parameters) - - def bucketSpec: Option[BucketSpec] - - def prepareJobForWrite(job: Job): BucketedOutputWriterFactory + createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) } private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { @@ -73,11 +51,5 @@ private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFact path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = - newInstance(path, None, dataSchema, context) - - def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter + throw new UnsupportedOperationException("use bucket version") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index fd2af77f314f..b92edf65bfb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -50,7 +50,7 @@ class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegi dataSchema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation = { + parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( inputRDD = None, @@ -68,11 +68,11 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val bucketSpec: Option[BucketSpec], + override val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) - extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) { + extends HadoopFsRelation(maybePartitionSpec, parameters) { def this( inputRDD: Option[RDD[String]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 83e456a75f72..4b375de05e9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -61,7 +61,7 @@ private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with D schema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - 
parameters: Map[String, String]): BucketedHadoopFsRelation = { + parameters: Map[String, String]): HadoopFsRelation = { new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } @@ -112,10 +112,10 @@ private[sql] class ParquetRelation( // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val bucketSpec: Option[BucketSpec], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) - extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index aece7566b4ff..119ba73fb788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -161,6 +161,20 @@ trait HadoopFsRelationProvider { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation + + // TODO: expose bucket API to users. + private[sql] def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], + parameters: Map[String, String]): HadoopFsRelation = { + if (bucketSpec.isDefined) { + throw new AnalysisException("Currently we don't support bucketing for this data source.") + } + createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec, parameters) + } } /** @@ -355,6 +369,14 @@ abstract class OutputWriterFactory extends Serializable { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + + // TODO: expose bucket API to users. + private[sql] def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + newInstance(path, dataSchema, context) } /** @@ -438,6 +460,9 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ + // TODO: expose bucket API to users. 
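// The defaults added here make bucketing opt-in: a provider that does not override the
// bucket-aware createRelation rejects any bucket spec, and an OutputWriterFactory that does
// not override the bucket-aware newInstance simply ignores the bucket id. A stripped-down
// sketch of that opt-in pattern; the traits and names are illustrative, not Spark's, and the
// zero-padded suffix only mimics the bucketed file naming described later in the patch.
object OptInBucketingSketch {
  trait WriterFactory {
    def newInstance(path: String): String // returns the file it would write, in this toy model
    // Bucket-aware variant; the default simply ignores the bucket id.
    def newInstance(path: String, bucketId: Option[Int]): String = newInstance(path)
  }

  // A source that supports bucketing overrides the bucket-aware variant and rejects the
  // non-bucketed one, much like BucketedOutputWriterFactory does above.
  class BucketedFactory extends WriterFactory {
    override def newInstance(path: String): String =
      sys.error("use the bucket-aware newInstance")
    override def newInstance(path: String, bucketId: Option[Int]): String =
      f"$path/part-r-00000-${bucketId.getOrElse(0)}%05d.parquet"
  }

  def main(args: Array[String]): Unit = {
    println(new BucketedFactory().newInstance("/tmp/out", Some(2))) // .../part-r-00000-00002.parquet
  }
}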
+ private[sql] def bucketSpec: Option[BucketSpec] = None + private class FileStatusCache { var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 5005010c0d39..14fa152c2331 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -53,7 +53,7 @@ private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with D dataSchema: Option[StructType], partitionColumns: Option[StructType], bucketSpec: Option[BucketSpec], - parameters: Map[String, String]): BucketedHadoopFsRelation = { + parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") @@ -156,10 +156,10 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - val bucketSpec: Option[BucketSpec], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends BucketedHadoopFsRelation(maybePartitionSpec, parameters) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( From 1afd3ee78484ce56dd04446bd43adab96a677411 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Jan 2016 20:37:54 +0800 Subject: [PATCH 15/15] minor update --- .../sql/execution/datasources/WriterContainer.scala | 6 +++++- .../spark/sql/execution/datasources/bucket.scala | 2 ++ .../spark/sql/execution/datasources/rules.scala | 13 ++++++++----- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 1 - .../spark/sql/sources/BucketedWriteSuite.scala | 8 +++++++- 6 files changed, 23 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 5e580c0be07a..4f8524f4b967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -410,7 +410,11 @@ private[sql] class DynamicPartitionWriterContainer( } } - /** Open and returns a new OutputWriter given a partition key and optional bucket id. */ + /** + * Open and returns a new OutputWriter given a partition key and optional bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ private def newOutputWriter( key: InternalRow, getPartitionString: UnsafeProjection): OutputWriter = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 86e5c4a985ee..82287c896713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -43,6 +43,8 @@ private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProv dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = + // TODO: throw exception here as we won't call this method during execution, after bucketed read + // support is finished. createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 04e9373d7ff1..d484403d1c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -194,12 +194,15 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => PartitioningUtils.validatePartitionColumnDataTypes( c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) - c.bucketSpec.foreach(_.sortColumnNames.foreach { sortCol => - val dataType = c.child.schema.find(_.name == sortCol).get.dataType - if (!RowOrdering.isOrderable(dataType)) { - failAnalysis(s"Cannot use ${dataType.simpleString} for sorting column.") + for { + spec <- c.bucketSpec + sortColumnName <- spec.sortColumnNames + sortColumn <- c.child.schema.find(_.name == sortColumnName) + } { + if (!RowOrdering.isOrderable(sortColumn.dataType)) { + failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") } - }) + } case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 119ba73fb788..c35f33132f60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -173,7 +173,7 @@ trait HadoopFsRelationProvider { if (bucketSpec.isDefined) { throw new AnalysisException("Currently we don't support bucketing for this data source.") } - createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec, parameters) + createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4d1f243aa7d2..bf65325d54fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,5 +1479,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index d156c43966b9..579da0291f29 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -33,9 +33,10 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) } - test("numBuckets not greater than 0") { + test("numBuckets not greater than 0 or less than 100000") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt")) + intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt")) } test("specify sorting columns without bucketing columns") { @@ -48,6 +49,11 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) } + test("write bucketed data to unsupported data source") { + val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i") + intercept[AnalysisException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) + } + test("write bucketed data to non-hive-table or existing hive table") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path"))
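// Taken together, these tests pin down how the new writer API is meant to be used: bucketBy
// and sortBy only work with saveAsTable, numBuckets must lie strictly between 0 and 100000,
// sortBy requires bucketBy, and sources without bucketing support (e.g. text) are rejected.
// A spark-shell style usage sketch under those rules, assuming this patch is applied and
// `sqlContext` is a HiveContext in scope; the table and column names are illustrative.
import sqlContext.implicits._

val df = (1 to 100).map(i => (i, i.toString)).toDF("i", "j")

df.write
  .format("parquet")
  .bucketBy(8, "i")              // 0 < numBuckets < 100000
  .sortBy("j")                   // sortBy is only legal together with bucketBy
  .saveAsTable("bucketed_table") // path-based writes such as .parquet(path) are rejected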