Command.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
* commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
* eagerly executed.
*/
trait Command extends LeafNode {
trait Command extends LogicalPlan {
override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
}
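
With Command extending LogicalPlan instead of LeafNode, a concrete command is no longer forced to be a leaf and can expose the query it wraps as a real child, which is what the insert commands further down do. A minimal sketch of what this enables, using a hypothetical command name that is not part of this patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}

// Hypothetical command that carries the query whose result it writes out.
// Declaring the query as a child lets the analyzer and optimizer treat it as part
// of the surrounding plan instead of hiding it inside an opaque leaf node.
case class CopyQueryResultCommand(query: LogicalPlan) extends Command {
  override def children: Seq[LogicalPlan] = query :: Nil
}
```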
SparkStrategies.scala
@@ -346,7 +346,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil

case MemoryPlan(sink, output) =>
val encoder = RowEncoder(sink.schema)
commands.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
@@ -36,14 +35,20 @@ import org.apache.spark.sql.types._
* wrapped in `ExecutedCommand` during execution.
*/
trait RunnableCommand extends logical.Command {
def run(sparkSession: SparkSession): Seq[Row]
def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
throw new NotImplementedError
}

def run(sparkSession: SparkSession): Seq[Row] = {
throw new NotImplementedError
}
}

/**
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
*/
case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan {
/**
* A concrete command should override this lazy field to wrap up any side effects caused by the
* command or any other computation that should be evaluated exactly once. The value of this field
@@ -55,14 +60,17 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
*/
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
val rows = if (children.isEmpty) {
cmd.run(sqlContext.sparkSession)
} else {
cmd.run(sqlContext.sparkSession, children)
}
rows.map(converter(_).asInstanceOf[InternalRow])
}

override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil

override def output: Seq[Attribute] = cmd.output

override def children: Seq[SparkPlan] = Nil
override def nodeName: String = cmd.nodeName

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

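
Put together, a command that needs the physical plan of its query overrides the new two-argument run, while leaf commands keep overriding the single-argument form; ExecutedCommandExec picks the overload based on whether any planned children were passed in. A rough sketch of a command written against the new contract, assuming the internal APIs shown in this diff and using hypothetical names:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.RunnableCommand

// Hypothetical write command: the logical query is exposed as a child, so the planner
// hands its planned SparkPlan to run() through ExecutedCommandExec.
case class WriteToSinkCommand(query: LogicalPlan) extends RunnableCommand {

  override def children: Seq[LogicalPlan] = query :: Nil

  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
    assert(children.length == 1)
    val physicalPlan = children.head
    // Consume the already-planned child rather than re-planning `query` from scratch,
    // e.g. hand physicalPlan to a writer that calls physicalPlan.execute() itself.
    physicalPlan.execute().count() // placeholder side effect for this sketch
    Seq.empty[Row]
  }
}
```

With the strategy change above, BasicOperators builds ExecutedCommandExec(r, r.children.map(planLater)), so sideEffectResult sees a non-empty children sequence and dispatches to the two-argument overload; commands with no child query continue to override the single-argument run.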
DataSource.scala
@@ -28,8 +28,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -408,16 +408,6 @@ case class DataSource(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)

// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
val partitionAttributes = partitionColumns.map { name =>
val plan = data.logicalPlan
plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
}.asInstanceOf[Attribute]
}
val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
@@ -430,7 +420,7 @@ case class DataSource(
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
partitionColumns = partitionAttributes,
partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
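
Because InsertIntoHadoopFsRelationCommand now keeps the query as a child and is analyzed together with it, DataSource no longer needs to pre-resolve the partition columns against the data's logical plan; it can hand over unresolved attributes and let the analyzer bind them. A small standalone illustration of what UnresolvedAttribute.quoted produces (a sketch, not code from this patch):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

// quoted() keeps the whole string as a single name part, so a column literally named
// "a.b" is not split into nested-field access the way the parsing apply() would split it.
val byParsing = UnresolvedAttribute("a.b")        // name parts: Seq("a", "b")
val byQuoting = UnresolvedAttribute.quoted("a.b") // name parts: Seq("a.b")
```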
FileFormatWriter.scala
@@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.execution.{SQLExecution, SortExec, SparkPlan}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}


@@ -96,7 +96,7 @@ object FileFormatWriter extends Logging {
*/
def write(
sparkSession: SparkSession,
queryExecution: QueryExecution,
plan: SparkPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
@@ -111,9 +111,9 @@
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

val allColumns = queryExecution.logical.output
val allColumns = plan.output
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
val dataColumns = allColumns.filterNot(partitionSet.contains)

val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -151,7 +151,7 @@ object FileFormatWriter extends Logging {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
val actualOrdering = plan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
@@ -170,12 +170,12 @@

try {
val rdd = if (orderingMatched) {
queryExecution.toRdd
plan.execute()
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = queryExecution.executedPlan).execute()
child = plan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
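
The sort check above now reads the output ordering straight from the physical plan that will be written, instead of going through a QueryExecution. A toy illustration of the kind of comparison involved (a sketch only; the patch's exact comparison code is truncated above):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.StringType

// Required ordering: the partition column "p"; actual ordering: the child plan already
// sorts by "p", so no extra SortExec needs to be inserted before writing.
val p = AttributeReference("p", StringType)()
val requiredOrdering = Seq(p)                                    // partitionColumns ++ bucketIdExpression ++ sortColumns
val actualOrdering   = Seq(SortOrder(p, Ascending)).map(_.child) // plan.outputOrdering.map(_.child)
val orderingMatched  = requiredOrdering.length <= actualOrdering.length &&
  requiredOrdering.zip(actualOrdering).forall { case (r, a) => r.semanticEquals(a) }
```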
InsertIntoHadoopFsRelationCommand.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._

/**
@@ -50,12 +51,13 @@ case class InsertIntoHadoopFsRelationCommand(
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex])
extends RunnableCommand {

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
override def children: Seq[LogicalPlan] = query :: Nil

override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
assert(children.length == 1)

override def run(sparkSession: SparkSession): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) {
val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
@@ -136,7 +138,7 @@ case class InsertIntoHadoopFsRelationCommand(

FileFormatWriter.write(
sparkSession = sparkSession,
queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
plan = children.head,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
InsertIntoHiveTable.scala
@@ -32,10 +32,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg
import org.apache.hadoop.hive.ql.plan.TableDesc

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.hive._
@@ -80,7 +81,7 @@ case class InsertIntoHiveTable(
overwrite: Boolean,
ifNotExists: Boolean) extends RunnableCommand {

override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
override def children: Seq[LogicalPlan] = query :: Nil

var createdTempDir: Option[Path] = None

@@ -217,7 +218,9 @@ case class InsertIntoHiveTable(
* `org.apache.hadoop.hive.serde2.SerDe` and the
* `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
*/
override def run(sparkSession: SparkSession): Seq[Row] = {
override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
assert(children.length == 1)

val sessionState = sparkSession.sessionState
val externalCatalog = sparkSession.sharedState.externalCatalog
val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
@@ -310,7 +313,7 @@

FileFormatWriter.write(
sparkSession = sparkSession,
queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
plan = children.head,
fileFormat = new HiveFileFormat(fileSinkConf),
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),