@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
val sparkContext: SparkContext = self.sparkContext

val sqlContext: SQLContext = self

def numPartitions = self.numShufflePartitions

@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.SQLContext

/**
* :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
child: SparkPlan)(@transient sqlContext: SQLContext)
extends UnaryNode with NoBind {

override def requiredChildDistribution =
@@ -55,7 +56,7 @@ }
}
}

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

// HACK: Generators don't correctly preserve their output through serializations so we grab
// our child's output attributes statically here.
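The `otherCopyArgs` override just above is the mechanism that makes the new curried sqlContext argument survive tree transformations: Catalyst rebuilds nodes reflectively from their primary constructor arguments, and anything in a second parameter list is invisible to a case class's product unless it is echoed back through otherCopyArgs. A self-contained illustration with a hypothetical node (not Spark's classes), mimicking that reflective copy:

case class ExampleNode(limit: Int)(@transient val context: String) {
  // Mirrors the convention in this patch: list every curried constructor
  // argument so a reflective copy can re-supply it.
  def otherCopyArgs: Seq[AnyRef] = context :: Nil
}

object OtherCopyArgsDemo extends App {
  val node = ExampleNode(10)("ctx")

  println(node.productIterator.toList)   // List(10) -- the curried arg is not part of the product

  // A reflective copy, roughly as a tree transformation would perform it:
  // primary constructor args first, then whatever otherCopyArgs contributes.
  val ctor = node.getClass.getConstructors.head
  val args = node.productIterator.map(_.asInstanceOf[AnyRef]).toArray ++ node.otherCopyArgs
  val copy = ctor.newInstance(args: _*).asInstanceOf[ExampleNode]

  println(copy.context)                  // "ctx" -- survives only because otherCopyArgs carried it
}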
@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sparkContext) :: Nil
planLater(left), planLater(right), condition)(sqlContext) :: Nil
case _ => Nil
}
}
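A note on where `sqlContext` (and `sparkContext`) resolve from in these strategies: SparkStrategies does not declare them itself; the references compile against the class that mixes the strategies in, SQLContext's SparkPlanner from the first hunk, presumably through a self-type on SparkStrategies (the declaration itself is outside this diff). A tiny self-contained analogue of that arrangement, with hypothetical names:

class Ctx { val name = "ctx" }

abstract class Strategies { self: Planner =>
  // Compiles because the self-type guarantees every concrete mixer is a
  // Planner, so `ctx` is known to exist here.
  def describe(): String = "planning with " + ctx.name
}

class Planner extends Strategies {
  val ctx: Ctx = new Ctx
}

object SelfTypeDemo extends App {
  println(new Planner().describe())   // prints "planning with ctx"
}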
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sparkContext))(sparkContext) :: Nil
planLater(child))(sqlContext))(sqlContext) :: Nil
} else {
Nil
}
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
case _ => Nil
}
}
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
case _ => Nil
}
}
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val relation =
ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil

case _ => Nil
}
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
execution.Aggregate(
partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
case logical.Sort(sortExprs, child) =>
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
execution.ExistingRdd(output, dataAsRdd) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child))(sparkContext) :: Nil
execution.Limit(limit, planLater(child))(sqlContext) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
case logical.NoRelation =>
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
* :: DeveloperApi ::
*/
@DeveloperApi
case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output = children.head.output
override def execute() = sc.union(children.map(_.execute()))
override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil
}

/**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
* data to a single partition to compute the global limit.
*/
@DeveloperApi
case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

override def output = child.output

@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
*/
@DeveloperApi
case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
(@transient sc: SparkContext) extends UnaryNode {
override def otherCopyArgs = sc :: Nil
(@transient sqlContext: SQLContext) extends UnaryNode {
override def otherCopyArgs = sqlContext :: Nil

override def output = child.output

@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)

// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
override def execute() = sc.makeRDD(executeCollect(), 1)
override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
}

/**
21 changes: 11 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution

import scala.collection.mutable.{ArrayBuffer, BitSet}

import org.apache.spark.SparkContext

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
(@transient sc: SparkContext)
(@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

def output = left.output

@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(


def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val broadcastedRelation =
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
@DeveloperApi
case class BroadcastNestedLoopJoin(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
(@transient sc: SparkContext)
(@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

def output = left.output ++ right.output

@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(


def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val broadcastedRelation =
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
}

// TODO: Breaks lineage.
sc.union(
streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
sqlContext.sparkContext.union(
streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
}
}
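Both join operators in this file now reach Spark only through `sqlContext.sparkContext`, and only where an RDD-level primitive is genuinely needed: collecting and broadcasting the build side, and unioning or parallelizing result RDDs. A hedged sketch of that core broadcast pattern as a standalone helper (hypothetical name; the real operators additionally evaluate the join condition and handle the different join types):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.execution.SparkPlan

object NestedLoopSketch {
  def crossPairs(streamed: SparkPlan, broadcast: SparkPlan, sqlContext: SQLContext): RDD[(Row, Row)] = {
    // Collect the build side once on the driver (copying rows, since operators
    // may reuse row objects) and broadcast it to every executor.
    val broadcastedRelation =
      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

    // Pair every streamed row with every broadcast row; a real nested-loop
    // join would filter these pairs with the join condition.
    streamed.execute().mapPartitions { streamedIter =>
      streamedIter.flatMap { streamedRow =>
        broadcastedRelation.value.map(broadcastedRow => (streamedRow, broadcastedRow))
      }
    }
  }
}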
@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType

import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.{Logging, SerializableWritable, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}

/**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
output: Seq[Attribute],
relation: ParquetRelation,
columnPruningPred: Seq[Expression])(
@transient val sc: SparkContext)
@transient val sqlContext: SQLContext)
extends LeafNode {

override def execute(): RDD[Row] = {
val sc = sqlContext.sparkContext
val job = new Job(sc.hadoopConfiguration)
ParquetInputFormat.setReadSupportClass(
job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
.filter(_ != null) // Parquet's record filters may produce null values
}

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

/**
* Applies a (candidate) projection.
@@ -104,7 +105,7 @@ def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
val success = validateProjection(prunedAttributes)
if (success) {
ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
} else {
sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
overwrite: Boolean = false)(
@transient val sc: SparkContext)
@transient val sqlContext: SQLContext)
extends UnaryNode with SparkHadoopMapReduceUtil {

/**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
val childRdd = child.execute()
assert(childRdd != null)

val job = new Job(sc.hadoopConfiguration)
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)

val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(

override def output = child.output

override def otherCopyArgs = sc :: Nil
override def otherCopyArgs = sqlContext :: Nil

/**
* Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
val wrappedConf = new SerializableWritable(job.getConfiguration)
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = sc.newRddId()
val stageId = sqlContext.sparkContext.newRddId()

val taskIdOffset =
if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
sc.runJob(rdd, writeShard _)
sqlContext.sparkContext.runJob(rdd, writeShard _)
jobCommitter.commitJob(jobTaskContext)
}
}
@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
Seq())(TestSQLContext.sparkContext)
Seq())(TestSQLContext)
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))
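The suite now hands the whole TestSQLContext to ParquetTableScan instead of just its SparkContext. A hedged usage sketch of what that buys at a call site (illustrative only, not part of the suite; assumes it runs inside the same test package so the Parquet test helpers are visible):

import org.apache.spark.sql.test.TestSQLContext

val scanner = new ParquetTableScan(
  ParquetTestData.testData.output,
  ParquetTestData.testData,
  Seq())(TestSQLContext)

// The curried argument is a val on the node, so the context the scan was built
// with stays reachable on this instance and on copies made via otherCopyArgs.
assert(scanner.sqlContext eq TestSQLContext)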