@@ -36,6 +36,8 @@ trait CatalystConf {

def warehousePath: String

def repartitionColumnarData: Boolean

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -55,5 +57,6 @@ case class SimpleCatalystConf(
optimizerInSetConversionThreshold: Int = 10,
maxCaseBranchesForCodegen: Int = 20,
runSQLonFile: Boolean = true,
warehousePath: String = "/user/hive/warehouse")
warehousePath: String = "/user/hive/warehouse",
repartitionColumnarData: Boolean = false)
extends CatalystConf
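
A minimal sketch (not part of this diff) of exercising the new flag through `SimpleCatalystConf`; it assumes `caseSensitiveAnalysis` is the conf's only required constructor parameter, as elsewhere on this branch, and uses named arguments so it does not depend on parameter order.

```scala
import org.apache.spark.sql.catalyst.SimpleCatalystConf

// Hedged sketch: enable the new columnar repartitioning flag in a test conf.
val conf = SimpleCatalystConf(
  caseSensitiveAnalysis = true,       // assumed required parameter on this branch
  repartitionColumnarData = true)     // new flag added by this change

assert(conf.repartitionColumnarData)  // read wherever a CatalystConf is available, e.g. the optimizer rule below
```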
@@ -101,6 +101,7 @@ class Analyzer(
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
ResolveOutputColumns ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
@@ -445,7 +446,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved =>
val table = lookupTableFromCatalog(u)
// adding the table's partitions or validate the query's partition info
table match {
@@ -499,6 +500,120 @@ class Analyzer(
}
}

object ResolveOutputColumns extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
case ins @ InsertIntoTable(relation: LogicalPlan, partition, _, _, _, _)
if ins.childrenResolved && !ins.resolved =>
resolveOutputColumns(ins, expectedColumns(relation, partition), relation.toString)
}

private def resolveOutputColumns(
insertInto: InsertIntoTable,
columns: Seq[Attribute],
relation: String) = {
val resolved = if (insertInto.isMatchByName) {
projectAndCastOutputColumns(columns, insertInto.child, relation)
} else {
castAndRenameOutputColumns(columns, insertInto.child, relation)
}

if (resolved == insertInto.child.output) {
insertInto
} else {
insertInto.copy(child = Project(resolved, insertInto.child))
}
}

/**
* Resolves output columns by input column name, adding casts if necessary.
*/
private def projectAndCastOutputColumns(
output: Seq[Attribute],
data: LogicalPlan,
relation: String): Seq[NamedExpression] = {
output.map { col =>
data.resolveQuoted(col.name, resolver) match {
case Some(inCol) if col.dataType != inCol.dataType =>
Alias(UpCast(inCol, col.dataType, Seq()), col.name)()
case Some(inCol) => inCol
case None =>
throw new AnalysisException(
s"Cannot resolve ${col.name} in ${data.output.mkString(",")}")
}
}
}

private def castAndRenameOutputColumns(
output: Seq[Attribute],
data: LogicalPlan,
relation: String): Seq[NamedExpression] = {
val outputNames = output.map(_.name)
// incoming expressions may not have names
val inputNames = data.output.flatMap(col => Option(col.name))
if (output.size > data.output.size) {
// always a problem
throw new AnalysisException(
s"""Not enough data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${outputNames.mkString(",")}""".stripMargin)
} else if (output.size < data.output.size) {
if (outputNames.toSet.subsetOf(inputNames.toSet)) {
throw new AnalysisException(
s"""Table column names are a subset of the input data columns:
|Data columns: ${inputNames.mkString(",")}
|Table columns: ${outputNames.mkString(",")}
|To write a subset of the columns by name, use df.write.byName.insertInto(...)"""
.stripMargin)
} else {
// be conservative and fail if there are too many columns
throw new AnalysisException(
s"""Extra data columns to write into $relation:
|Data columns: ${data.output.mkString(",")}
|Table columns: ${outputNames.mkString(",")}""".stripMargin)
}
} else {
// check for reordered names and warn. this may be on purpose, so it isn't an error.
if (outputNames.toSet == inputNames.toSet && outputNames != inputNames) {
logWarning(
s"""Data column names match the table in a different order:
|Data columns: ${inputNames.mkString(",")}
|Table columns: ${outputNames.mkString(",")}
|To map columns by name, use df.write.byName.insertInto(...)""".stripMargin)
}
}

data.output.zip(output).map {
case (in, out) if !in.dataType.sameType(out.dataType) =>
Alias(Cast(in, out.dataType), out.name)()
case (in, out) if in.name != out.name =>
Alias(in, out.name)()
case (in, _) => in
}
}

private def expectedColumns(
data: LogicalPlan,
partitionData: Map[String, Option[String]]): Seq[Attribute] = {
data match {
case partitioned: CatalogRelation =>
val tablePartitionNames = partitioned.catalogTable.partitionColumns.map(_.name)
val (inputPartCols, dataColumns) = data.output.partition { attr =>
tablePartitionNames.contains(attr.name)
}
// Get the dynamic partition columns in partition order
val dynamicNames = tablePartitionNames.filter(
name => partitionData.getOrElse(name, None).isEmpty)
val dynamicPartCols = dynamicNames.map { name =>
inputPartCols.find(_.name == name).getOrElse(
throw new AnalysisException(s"Cannot find partition column $name"))
}

dataColumns ++ dynamicPartCols
case _ => data.output
}
}
}
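
The error messages above point users at a by-name write mode. A hedged usage sketch follows, assuming the DataFrameWriter side of this change exposes `byName` and that it sets the `matchByName` option consumed by `InsertIntoTable.isMatchByName`; the writer API itself is not shown in this diff.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
val df = spark.table("staging_events")   // hypothetical source; columns match target_table by name, not position

// Default: columns are matched by position, so ResolveOutputColumns casts/renames
// them via castAndRenameOutputColumns and only warns when names appear reordered.
df.write.insertInto("target_table")

// Assumed by-name mode from this change: columns are resolved by name through
// projectAndCastOutputColumns, with UpCast inserted where data types differ.
df.write.byName.insertInto("target_table")
```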

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
@@ -118,6 +118,16 @@ case class CatalogTable(
def partitionColumns: Seq[CatalogColumn] =
schema.filter { c => partitionColumnNames.contains(c.name) }

/** Columns this table is bucketed by. */
private[sql] val bucketColumns: Seq[CatalogColumn] = bucketColumnNames.flatMap { name =>
schema.find(_.name == name)
}

/** Columns this table is sorted by. */
private[sql] val sortColumns: Seq[CatalogColumn] = sortColumnNames.flatMap { name =>
schema.find(_.name == name)
}
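
A small sketch (not in the diff) of the ordering guarantee behind these two vals: mapping over the name lists preserves the declared bucket/sort order, whereas filtering the schema (as `partitionColumns` does) would preserve schema order.

```scala
// Generic illustration of the lookup pattern used above, with hypothetical names.
case class Col(name: String)
val schema = Seq(Col("a"), Col("b"), Col("c"))
val bucketColumnNames = Seq("c", "a")

val byNames  = bucketColumnNames.flatMap(n => schema.find(_.name == n))  // Seq(Col(c), Col(a)) - declared order
val bySchema = schema.filter(c => bucketColumnNames.contains(c.name))    // Seq(Col(a), Col(c)) - schema order
```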

/** Return the database this table was specified to belong to, assuming it exists. */
def database: String = identifier.database.getOrElse {
throw new AnalysisException(s"table $identifier did not specify database")
@@ -367,7 +367,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
Map.empty, logicalPlan, overwrite, ifNotExists = false, Map.empty)

def as(alias: String): LogicalPlan = logicalPlan match {
case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
@@ -63,6 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection)
def isAscending: Boolean = direction == Ascending
}

// TODO: should this be an implicit class somewhere?
object SortOrder {
def satisfies(order: Seq[SortOrder], distribution: Distribution): Boolean = {
distribution match {
case c @ ClusteredDistribution(exprs) =>
// Zip discards extra order by expressions
(order.size >= exprs.size) && exprs.zip(order.map(_.child)).forall {
case (clusterExpr, orderExpr) => clusterExpr.semanticEquals(orderExpr)
}
case _ => false
}
}
}
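
A minimal sketch (not part of the diff) of the intended semantics, using the catalyst expression DSL: an ordering satisfies a `ClusteredDistribution` when the clustering expressions are a prefix of the sort expressions.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution

val a = 'a.int
val b = 'b.int
val order = Seq(SortOrder(a, Ascending), SortOrder(b, Ascending))

SortOrder.satisfies(order, ClusteredDistribution(Seq(a)))             // true: (a) is a prefix of (a, b)
SortOrder.satisfies(order, ClusteredDistribution(Seq(b)))             // false: b is not the leading sort expression
SortOrder.satisfies(order.take(1), ClusteredDistribution(Seq(a, b)))  // false: fewer sort than clustering expressions
```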

/**
* An expression to generate a 64-bit long prefix used in sorting.
*/
@@ -23,13 +23,14 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._

@@ -72,6 +73,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Aggregate", fixedPoint,
RemoveLiteralFromGroupExpressions,
RemoveRepetitionFromGroupExpressions) ::
Batch("Output Layout Optimizations", Once,
DistributeAndSortOutputData(conf)) ::
Batch("Operator Optimizations", fixedPoint,
// Operator push down
SetOperationPushDown,
@@ -1737,3 +1740,82 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
}
}

case class DistributeAndSortOutputData(conf: CatalystConf) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _)
if insertInto.resolved && insertInto.writersPerPartition.isDefined =>
insertInto.copy(child =
buildRepartitionAndSort(rel.catalogTable, data, insertInto.writersPerPartition))

case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _)
if insertInto.resolved && requiresSort(rel.catalogTable) =>
insertInto.copy(child = buildSort(rel.catalogTable, data))

case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _)
if insertInto.resolved && isColumnar(rel.catalogTable) && shouldRepartition(data) =>
insertInto.copy(child = buildRepartitionAndSort(rel.catalogTable, data, None))
}

private def isColumnar(table: CatalogTable): Boolean = {
table.storage.serde.map(_.toLowerCase)
.forall(serde => serde.contains("parquet") || serde.contains("orc"))
}

private def shouldRepartition(plan: LogicalPlan): Boolean = {
// automatically add repartitioning for columnar formats when enabled, unless the plan
// already contains an explicit sort or repartition that it would conflict with
conf.repartitionColumnarData && !hasSortOrRepartition(plan)
}

private def hasSortOrRepartition(plan: LogicalPlan): Boolean = {
plan.collectFirst {
case _: RepartitionByExpression => true
case _: Sort => true
}.getOrElse(false)
}

private def requiresSort(table: CatalogTable): Boolean = {
(table.bucketColumnNames.size + table.sortColumnNames.size) > 0
}

private def buildSort(table: CatalogTable, data: LogicalPlan): LogicalPlan = {
val partitionExprs = asExpr(table.partitionColumns, data)
val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data)
val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data)
// add a sort without a repartition
Sort(sortExprs.map(expr => SortOrder(expr, Ascending)), global = false, data)
}

private def buildRepartitionAndSort(
table: CatalogTable,
data: LogicalPlan,
numWriters: Option[Int]): LogicalPlan = {
val partitionExprs = asExpr(table.partitionColumns, data) ++ asDistributeExpr(numWriters)
val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data)
val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data)

// add a sort with an inner repartition
Sort(
sortExprs.map(expr => SortOrder(expr, Ascending)),
global = false,
RepartitionByExpression(partitionExprs, data, None))
}

private def asExpr(columns: Seq[CatalogColumn], data: LogicalPlan): Seq[Attribute] = {
columns.map(col => data.output.find(_.name == col.name).get)
}

private def asDistributeExpr(numWriters: Option[Int]): Option[Expression] = {
numWriters.map(n => Pmod(Cast(Multiply(Rand(0L), Literal(n)), IntegerType), Literal(n)))
}

private def asBucketExpr(columns: Seq[CatalogColumn], numBuckets: Int,
data: LogicalPlan): Option[Expression] = {
if (columns.isEmpty) {
None
} else {
Some(HashPartitioning(asExpr(columns, data), numBuckets).partitionIdExpression)
}
}
}
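
A sketch (not taken from the diff) of the bucket expression this rule sorts by, and of the plan shape `buildRepartitionAndSort` produces: a local `Sort` over a `RepartitionByExpression` on the partition expressions, with the bucket id and sort columns appended to the sort key.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// The bucket expression is HashPartitioning(...).partitionIdExpression, i.e. a
// deterministic pmod(hash(bucket columns), numBuckets); sorting on it groups the
// rows of each bucket together within a task so one file per bucket can be written.
val id = 'id.long
val bucketId = HashPartitioning(Seq(id), numPartitions = 8).partitionIdExpression

// Resulting plan shape for a table partitioned by dt, bucketed by id, sorted by ts:
//   InsertIntoTable
//   +- Sort [dt ASC, pmod(hash(id), 8) ASC, ts ASC], global = false
//      +- RepartitionByExpression [dt]
//         +- <query>
```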
@@ -211,8 +211,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
UnresolvedRelation(tableIdent, None),
partitionKeys,
query,
ctx.OVERWRITE != null,
ctx.EXISTS != null)
overwrite = ctx.OVERWRITE != null,
ifNotExists = ctx.EXISTS != null,
Map.empty /* SQL always matches by position */)
}

/**
@@ -348,28 +348,45 @@ case class InsertIntoTable(
partition: Map[String, Option[String]],
child: LogicalPlan,
overwrite: Boolean,
ifNotExists: Boolean)
ifNotExists: Boolean,
options: Map[String, String])
extends LogicalPlan {

override def children: Seq[LogicalPlan] = child :: Nil
override def output: Seq[Attribute] = Seq.empty

private[spark] def isMatchByName: Boolean = {
options.get("matchByName").map(_.toBoolean).getOrElse(false)
}

private[spark] def writersPerPartition: Option[Int] = {
options.get("writersPerPartition").map(_.toInt)
}

private[spark] lazy val expectedColumns = {
if (table.output.isEmpty) {
None
} else {
val numDynamicPartitions = partition.values.count(_.isEmpty)
val dynamicPartitionNames = partition.filter {
case (name, Some(_)) => false
case (name, None) => true
}.keySet
val (partitionColumns, dataColumns) = table.output
.partition(a => partition.keySet.contains(a.name))
Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
Some(dataColumns ++ partitionColumns.filter(col => dynamicPartitionNames.contains(col.name)))
}
}

assert(overwrite || !ifNotExists)
override lazy val resolved: Boolean = childrenResolved && expectedColumns.forall { expected =>
child.output.size == expected.size && child.output.zip(expected).forall {
case (childAttr, tableAttr) =>
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
override lazy val resolved: Boolean = childrenResolved && {
expectedColumns match {
case Some(expected) =>
child.output.size == expected.size && child.output.zip(expected).forall {
case (childAttr, tableAttr) =>
childAttr.name == tableAttr.name && // required by some relations
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
case None => true
}
}
}
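
A minimal construction sketch (not in the diff) of the widened `InsertIntoTable` signature; the option keys are the two read by `isMatchByName` and `writersPerPartition` above, and how the write path populates them is outside this file.

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}

val query: LogicalPlan = ???   // hypothetical resolved query producing the rows to insert

val insert = InsertIntoTable(
  table = UnresolvedRelation(TableIdentifier("target_table"), None),
  partition = Map("dt" -> None),            // one dynamic partition column
  child = query,
  overwrite = false,
  ifNotExists = false,
  options = Map(
    "matchByName" -> "true",                // read by isMatchByName: resolve output columns by name
    "writersPerPartition" -> "4"))          // read by writersPerPartition: spread each partition across 4 writers
```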