From 495eacdc12b0c5bf0198d43088ce3bde345417de Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Sun, 2 Nov 2014 20:57:51 -0600 Subject: [PATCH 01/79] Adding Timestamp and Date classes which support the standard comparison operators, as well as implicit conversions to support using these classes in the catalalyst DSL. This commit also adds a method to Row which builds a row from a schema and a list of strings. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 33 ++++++- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 35 +++++++ .../spark/sql/catalyst/expressions/Row.scala | 48 ++++++++- .../expressions/SpecificMutableRow.scala | 49 ++++++++++ .../sql/catalyst/expressions/literals.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 5 +- .../spark/sql/catalyst/types/timetypes.scala | 97 +++++++++++++++++++ .../spark/sql/columnar/ColumnType.scala | 5 + .../scala/org/apache/spark/sql/package.scala | 38 +++++++- 11 files changed, 303 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8fbdf664b71e4..8e8ba2dda4f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -124,9 +124,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DateType.JvmType => DateType case obj: BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited + case obj: DateType.JvmType => DateType case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7e6d770314f5a..9b3aadc3536c0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -125,9 +125,9 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) - implicit def dateToLiteral(d: Date) = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) implicit def decimalToLiteral(d: Decimal) = Literal(d) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -146,6 +146,31 @@ package object dsl { def upper(e: Expression) = Upper(e) def lower(e: Expression) = Lower(e) + /* + * Conversions to provide the standard operators in the special case + * where a literal is being combined with a symbol. Without these an + * expression such as 0 < 'x is not recognized. 
+ */ + implicit class InitialLiteral(x: Any) { + val literal = Literal(x) + def + (other: Symbol):Expression = {literal + other} + def - (other: Symbol):Expression = {literal - other} + def * (other: Symbol):Expression = {literal * other} + def / (other: Symbol):Expression = {literal / other} + def % (other: Symbol):Expression = {literal % other} + + def && (other: Symbol):Expression = {literal && other} + def || (other: Symbol):Expression = {literal || other} + + def < (other: Symbol):Expression = {literal < other} + def <= (other: Symbol):Expression = {literal <= other} + def > (other: Symbol):Expression = {literal > other} + def >= (other: Symbol):Expression = {literal >= other} + def === (other: Symbol):Expression = {literal === other} + def <=> (other: Symbol):Expression = {literal <=> other} + def !== (other: Symbol):Expression = {literal !== other} + } + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { @@ -182,9 +207,6 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() - /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() - /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() @@ -192,6 +214,9 @@ package object dsl { def decimal(precision: Int, scale: Int) = AttributeReference(s, DecimalType(precision, scale), nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 22009666196a1..38172eb1a50fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -31,8 +31,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true case (StringType, DateType) => true + case (StringType, TimestampType) => true case (_: NumericType, DateType) => true case (BooleanType, DateType) => true case (DateType, _: NumericType) => true @@ -333,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case DateType => castToDate case decimal: DecimalType => castToDecimal(decimal) + case DateType => castToDate case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index e7e81a21fdf03..4391bcede66d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -139,6 +140,12 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -231,6 +238,13 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -317,6 +331,13 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -403,6 +424,13 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -489,6 +517,13 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d00ec39774c35..738a5f3a53f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.catalyst.types._ +import java.sql.{Date, Timestamp} +import java.math.BigDecimal object Row { /** @@ -42,6 +44,31 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of Strings, + * converting each item to the type specified in a [[StructType]] schema. + * Only primitive types can be used. + */ + def fromStringsBySchema(strings: Seq[String], schema: StructType): Row = { + val values = for { + (field, str) <- schema.fields zip strings + item = field.dataType match { + case IntegerType => str.toInt + case LongType => str.toLong + case DoubleType => str.toDouble + case FloatType => str.toFloat + case ByteType => str.toByte + case ShortType => str.toShort + case StringType => str + case BooleanType => (str != "") + case DateType => Date.valueOf(str) + case TimestampType => Timestamp.valueOf(str) + case DecimalType() => new BigDecimal(str) + } + } yield item + new GenericRow(values.toArray) + } } /** @@ -64,6 +91,8 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getDate(i: Int): Date + def getTimestamp(i: Int): Timestamp def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = @@ -99,6 +128,8 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + def setDate(ordinal: Int, value: Date) + def setTimestamp(ordinal: Int, value: Timestamp) } /** @@ -119,6 +150,9 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + def getDate(i: Int): Date = throw new UnsupportedOperationException + def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this @@ -183,6 +217,16 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + def getDate(i: Int): Date = { + if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") + values(i).asInstanceOf[Date] + } + + def getTimestamp(i: Int): Timestamp = { + if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") + values(i).asInstanceOf[Timestamp] + } + // Custom hashCode function that matches the efficient code generated version. 
override def hashCode(): Int = { var result: Int = 37 @@ -226,6 +270,8 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } + override def setDate(ordinal: Int,value: Date): Unit = { values(ordinal) = value } + override def setTimestamp(ordinal: Int,value: Timestamp): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 570379c533e1f..34a10cf4a6945 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ +import java.sql.{Date, Timestamp} /** * A parent class for mutable container objects that are reused when the values are changed, @@ -169,6 +170,35 @@ final class MutableByte extends MutableValue { newCopy.asInstanceOf[this.type] } } +final class MutableDate extends MutableValue { + var value: Date = new Date(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Date] + } + def copy() = { + val newCopy = new MutableDate + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableTimestamp extends MutableValue { + var value: Timestamp = new Timestamp(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Timestamp] + } + def copy() = { + val newCopy = new MutableTimestamp + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} final class MutableAny extends MutableValue { var value: Any = _ @@ -307,6 +337,25 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableByte].value } + override def setDate(ordinal: Int, value: Date): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableDate] + currentValue.isNull = false + currentValue.value = value + } + + override def getDate(i: Int): Date = { + values(i).asInstanceOf[MutableDate].value + } + override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] + currentValue.isNull = false + currentValue.value = value + } + + override def getTimestamp(i: Int): Timestamp = { + values(i).asInstanceOf[MutableTimestamp].value + } + override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 93c19325151bf..548a9185998c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -34,8 +34,8 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), 
DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) case d: Date => Literal(d, DateType) + case t: Timestamp => Literal(t, TimestampType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 8dda0b182805c..11e049061e6b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -92,9 +92,9 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType | "DecimalType()" ^^^ DecimalType.Unlimited | fixedDecimalType + | "DateType" ^^^ DateType | "TimestampType" ^^^ TimestampType ) @@ -187,7 +187,8 @@ case object NullType extends DataType object NativeType { val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, + ByteType, StringType, DateType, TimestampType) def unapply(dt: DataType): Boolean = all.contains(dt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala new file mode 100644 index 0000000000000..189412b312ccc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date => JDate, Timestamp => JTimestamp} +import scala.language.implicitConversions + +/* + * Subclass of java.sql.Date which provides the usual comparison + * operators (as required for catalyst expressions) and which can + * be constructed from a string. + * + * scala> val d1 = Date("2014-02-01") + * d1: Date = 2014-02-01 + * + * scala> val d2 = Date("2014-02-02") + * d2: Date = 2014-02-02 + * + * scala> d1 < d2 + * res1: Boolean = true + */ + +class Date(milliseconds: Long) extends JDate(milliseconds) { + def <(that: Date): Boolean = this.before(that) + def >(that: Date): Boolean = this.after(that) + def <=(that: Date): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Date): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Date): Boolean = this.equals(that) +} + +object Date { + def apply(init: String) = new Date(JDate.valueOf(init).getTime) +} + +/* + * Analogous subclass of java.sql.Timestamp. 
+ * + * scala> val ts1 = Timestamp("2014-03-04 12:34:56.12") + * ts1: Timestamp = 2014-03-04 12:34:56.12 + * + * scala> val ts2 = Timestamp("2014-03-04 12:34:56.13") + * ts2: Timestamp = 2014-03-04 12:34:56.13 + * + * scala> ts1 < ts2 + * res13: Boolean = true + */ + +class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { + def <(that: Timestamp): Boolean = this.before(that) + def >(that: Timestamp): Boolean = this.after(that) + def <=(that: Timestamp): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Timestamp): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Timestamp): Boolean = this.equals(that) +} + +object Timestamp { + def apply(init: String) = new Timestamp(JTimestamp.valueOf(init).getTime) +} + +/* + * Implicit conversions. + */ + +object TimeConversions { + + implicit def JDateToDate(jdate: JDate): Date = { + new Date(jdate.getTime) + } + + implicit def JTimestampToTimestamp(jtimestamp: JTimestamp): Timestamp = { + new Timestamp(jtimestamp.getTime) + } + + implicit def DateToJDate(date: Date): JDate = { + new JDate(date.getTime) + } + + implicit def TimestampToJTimestamp(timestamp: Timestamp): JTimestamp = { + new JTimestamp(timestamp.getTime) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index ab66c85c4f242..475b65c2798c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -372,6 +372,11 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } + + def append(v: Date, buffer: ByteBuffer) { + buffer.putLong(v.getTime) + } + } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 05926a24c5307..2d73db0d6bc48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -161,22 +161,22 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `java.sql.Timestamp` values. + * The data type representing `java.sql.Date` values. * * @group dataType */ @DeveloperApi - val TimestampType = catalyst.types.TimestampType + val DateType = catalyst.types.DateType /** * :: DeveloperApi :: * - * The data type representing `java.sql.Date` values. + * The data type representing `java.sql.Timestamp` values. * * @group dataType */ @DeveloperApi - val DateType = catalyst.types.DateType + val TimestampType = catalyst.types.TimestampType /** * :: DeveloperApi :: @@ -451,4 +451,34 @@ package object sql { * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. */ type MetadataBuilder = catalyst.util.MetadataBuilder + + /** + * :: DeveloperApi :: + * + * A Timestamp class which support the standard comparison + * operators, for use in DSL expressions. Implicit conversions to + * java.sql.Date are provided. The class intializer accepts a + * String, e.g. + * + * val ts = Date("2014-01-01") + * + * @group dataType + */ + @DeveloperApi + val Date = catalyst.expressions.Date + + /** + * :: DeveloperApi :: + * + * A Timestamp class which support the standard comparison + * operators, for use in DSL expressions. 
Implicit conversions to + * java.sql.timestamp are provided. The class intializer accepts a + * String, e.g. + * + * val ts = Timestamp("2014-01-01 12:34:56.78") + * + * @group timeClasses + */ + @DeveloperApi + val Timestamp = catalyst.expressions.Timestamp } From 0c389a7e485d346c4dc26d09a993695d9826f37c Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Sun, 2 Nov 2014 21:11:35 -0600 Subject: [PATCH 02/79] Correcting a typo in the documentation. --- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 2d73db0d6bc48..e105dbb8d47bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -455,10 +455,9 @@ package object sql { /** * :: DeveloperApi :: * - * A Timestamp class which support the standard comparison - * operators, for use in DSL expressions. Implicit conversions to - * java.sql.Date are provided. The class intializer accepts a - * String, e.g. + * A Date class which support the standard comparison operators, for + * use in DSL expressions. Implicit conversions to java.sql.Date + * are provided. The class intializer accepts a String, e.g. * * val ts = Date("2014-01-01") * From d6e4c5917522b9fb6653ddc0634e93ff2dcf82be Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 2 Nov 2014 21:56:07 -0800 Subject: [PATCH 03/79] Close #2971. From 001acc446345ccb1e494af9ff1d16dd65db8034e Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 2 Nov 2014 22:02:05 -0800 Subject: [PATCH 04/79] [SPARK-4177][Doc]update build doc since JDBC/CLI support hive 13 now Author: wangfei Closes #3042 from scwf/patch-9 and squashes the following commits: 3784ed1 [wangfei] remove 'TODO' 1891553 [wangfei] update build doc since JDBC/CLI support hive 13 --- docs/building-spark.md | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 4cc0b1f2e5116..238ddae15545e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -99,14 +99,11 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} - - # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` profile to your existing build options. By default Spark will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using -the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only -supported for Hive 0.12.0. +the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package @@ -121,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. 
The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package - mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -185,16 +182,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test # Speeding up Compilation with Zinc From 76386e1a23c55a58c0aeea67820aab2bac71b24b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 2 Nov 2014 23:20:22 -0800 Subject: [PATCH 05/79] [SPARK-4163][Core][WebUI] Send the fetch failure message back to Web UI This is a PR to send the fetch failure message back to Web UI. Before: ![f1](https://cloud.githubusercontent.com/assets/1000778/4856595/1f036c80-60be-11e4-956f-335147fbccb7.png) ![f2](https://cloud.githubusercontent.com/assets/1000778/4856596/1f11cbea-60be-11e4-8fe9-9f9b2b35c884.png) After (Please ignore the meaning of exception, I threw it in the code directly because it's hard to simulate a fetch failure): ![e1](https://cloud.githubusercontent.com/assets/1000778/4856600/2657ea38-60be-11e4-9f2d-d56c5f900f10.png) ![e2](https://cloud.githubusercontent.com/assets/1000778/4856601/26595008-60be-11e4-912b-2744af786991.png) Author: zsxwing Closes #3032 from zsxwing/SPARK-4163 and squashes the following commits: f7e1faf [zsxwing] Discard changes for FetchFailedException and minor modification 4e946f7 [zsxwing] Add e as the cause of SparkException 316767d [zsxwing] Add private[storage] to FetchResult d51b0b6 [zsxwing] Set e as the cause of FetchFailedException b88c919 [zsxwing] Use 'private[storage]' for case classes instead of 'sealed' 62103fd [zsxwing] Update as per review 0c07d1f [zsxwing] Backward-compatible support a3bca65 [zsxwing] Send the fetch failure message back to Web UI --- .../org/apache/spark/TaskEndReason.scala | 6 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../apache/spark/scheduler/JobLogger.scala | 2 +- .../spark/shuffle/FetchFailedException.scala | 16 ++-- .../hash/BlockStoreShuffleFetcher.scala | 14 ++-- .../storage/ShuffleBlockFetcherIterator.scala | 82 ++++++++++++------- .../org/apache/spark/util/JsonProtocol.scala | 7 +- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 10 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 4 +- 12 files changed, 92 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 202fba699ab26..f45b463fb6f62 100644 --- 
a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -69,11 +69,13 @@ case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, - reduceId: Int) + reduceId: Int, + message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + + s"message=\n$message\n)" } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index af17b5d5d2571..96114c0423a9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1053,7 +1053,7 @@ class DAGScheduler( logInfo("Resubmitted " + task + ", so marking it as still running") stage.pendingTasks += task - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) @@ -1063,7 +1063,7 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure")) + markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage)) runningStages -= failedStage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 54904bffdf10b..4e3d9de540783 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -215,7 +215,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 71c08e9d5a8c3..0c1b6f4defdb3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.util.Utils /** * Failed to fetch a shuffle block. 
The executor catches this exception and propagates it @@ -30,13 +31,11 @@ private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, - reduceId: Int) - extends Exception { - - override def getMessage: String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + reduceId: Int, + message: String) + extends Exception(message) { - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) } /** @@ -46,7 +45,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId) { - - override def getMessage: String = message -} + extends FetchFailedException(null, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index f49917b7fe833..0d5247f4176d4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -19,12 +19,13 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, Utils} private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -52,21 +53,22 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Some(block) => { + case Success(block) => { block.asInstanceOf[Iterator[T]] } - case None => { + case Failure(e) => { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, + Utils.exceptionString(e)) case _ => throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") + "Failed to get block " + blockId + ", which is not a shuffle block", e) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ee89c7e521f4e..1e579187e4193 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, TaskContext} import 
org.apache.spark.network.BlockTransferService @@ -55,7 +56,7 @@ final class ShuffleBlockFetcherIterator( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { import ShuffleBlockFetcherIterator._ @@ -118,16 +119,18 @@ final class ShuffleBlockFetcherIterator( private[this] def cleanup() { isZombie = true // Release the current buffer if necessary - if (currentResult != null && !currentResult.failed) { - currentResult.buf.release() + currentResult match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() - if (!result.failed) { - result.buf.release() + result match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => } } } @@ -151,7 +154,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) shuffleMetrics.remoteBytesRead += buf.size shuffleMetrics.remoteBlocksFetched += 1 } @@ -160,7 +163,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FetchResult(BlockId(blockId), -1, null)) + results.put(new FailureFetchResult(BlockId(blockId), e)) } } ) @@ -231,12 +234,12 @@ final class ShuffleBlockFetcherIterator( val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() - results.put(new FetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(blockId, -1, null)) + results.put(new FailureFetchResult(blockId, e)) return } } @@ -267,15 +270,17 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Option[Iterator[Any]]) = { + override def next(): (BlockId, Try[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (!result.failed) { - bytesInFlight -= result.size + + result match { + case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case _ => } // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && @@ -283,20 +288,21 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { - None - } else { - val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream()) - val iter = serializer.newInstance().deserializeStream(is).asIterator - Some(CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. 
- currentResult = null - result.buf.release() - })) + val iteratorTry: Try[Iterator[Any]] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => { + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Success(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + })) + } } - (result.blockId, iteratorOpt) + (result.blockId, iteratorTry) } } @@ -315,14 +321,30 @@ object ShuffleBlockFetcherIterator { } /** - * Result of a fetch from a remote block. A failure is represented as size == -1. + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. -1 if failure is present. - * @param buf [[ManagedBuffer]] for the content. null is error. + * Note that this is NOT the exact bytes. + * @param buf [[ManagedBuffer]] for the content. */ - case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { - def failed: Boolean = size == -1 - if (failed) assert(buf == null) else assert(buf != null) + private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + extends FetchResult { + require(buf != null) + require(size >= 0) } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId block id + * @param e the failure exception + */ + private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 43c7fba06694a..f7ae1f7f334de 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -279,7 +279,8 @@ private[spark] object JsonProtocol { ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ - ("Reduce ID" -> fetchFailed.reduceId) + ("Reduce ID" -> fetchFailed.reduceId) ~ + ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) @@ -629,7 +630,9 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId) + val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, + message.getOrElse("Unknown reason")) case `exceptionFailure` => val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b402c5f334bb0..a33046d2040d8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ 
b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1597,7 +1597,7 @@ private[spark] object Utils extends Logging { } /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString(e: Exception): String = { + def exceptionString(e: Throwable): String = { if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a2e4f712db55b..819f95634bcdc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -431,7 +431,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -461,7 +461,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), null, @@ -472,7 +472,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), null, @@ -624,7 +624,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -655,7 +655,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostB", 1)))) // pretend stage 0 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 28f766570e96f..1eaabb93adbed 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -102,7 +102,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, subIterator) = iterator.next() - assert(subIterator.isDefined, + assert(subIterator.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release the buffer once the iterator is exhausted. @@ -230,8 +230,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { sem.acquire() // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isDefined === true) - assert(iterator.next()._2.isDefined === false) - assert(iterator.next()._2.isDefined === false) + assert(iterator.next()._2.isSuccess) + assert(iterator.next()._2.isFailure) + assert(iterator.next()._2.isFailure) } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 6567c5ab836e7..2efbae689771a 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -115,7 +115,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // Go through all the failure cases to make sure we are counting them as failures. val taskFailedReasons = Seq( Resubmitted, - new FetchFailed(null, 0, 0, 0), + new FetchFailed(null, 0, 0, 0, "ignored"), new ExceptionFailure("Exception", "description", null, None), TaskResultLost, TaskKilled, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index d235d7a0ed839..a91c9ddeaef36 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -107,7 +107,8 @@ class JsonProtocolSuite extends FunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19) + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Some exception") val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) @@ -396,6 +397,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.mapId === r2.mapId) assert(r1.reduceId === r2.reduceId) assertEquals(r1.bmAddress, r2.bmAddress) + assert(r1.message === r2.message) case (r1: ExceptionFailure, r2: ExceptionFailure) => assert(r1.className === r2.className) assert(r1.description === r2.description) From 7c8b2c041056c167bfadb10ef1a140173f549aa6 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Mon, 3 Nov 2014 08:45:43 -0600 Subject: [PATCH 06/79] Correcting the bugs and issues pointed out in liancheng's very helpful comments. 
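For context, the dsl/package.scala changes below refine the left-hand-side literal conversions introduced in the first patch. The need for them can be seen in a standalone sketch (plain Scala with invented names, not the catalyst code itself): an expression such as 0 < 'x does not resolve on its own, because Int has no < method that accepts a Symbol, so an implicit wrapper around the literal has to supply the operator.

    // Standalone illustration only; all names here are invented.
    object LhsLiteralSketch {
      // Stand-in for a catalyst Expression; it just records what was built.
      case class Expr(repr: String)

      // The DSL side: a Symbol on the left already works, e.g. 'x > 0.
      implicit class RichSymbol(sym: Symbol) {
        def >(other: Any): Expr = Expr(s"${sym.name} > $other")
        def <(other: Any): Expr = Expr(s"${sym.name} < $other")
      }

      // The wrapper: lets a plain value sit on the left of the operator.
      implicit class LhsValue(x: Any) {
        def <(other: Symbol): Expr = Expr(s"$x < ${other.name}")
        def >(other: Symbol): Expr = Expr(s"$x > ${other.name}")
      }

      def main(args: Array[String]): Unit = {
        println('x > 0)  // resolves through RichSymbol
        println(0 < 'x)  // resolves only because LhsValue supplies a < that takes a Symbol
      }
    }
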
--- .../spark/sql/catalyst/dsl/package.scala | 47 ++++++++++++------- .../sql/catalyst/expressions/Projection.scala | 34 -------------- .../spark/sql/catalyst/expressions/Row.scala | 27 ----------- .../expressions/SpecificMutableRow.scala | 7 --- .../spark/sql/catalyst/types/timetypes.scala | 26 +++++----- .../scala/org/apache/spark/sql/package.scala | 10 ++-- 6 files changed, 50 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 9b3aadc3536c0..c258235444720 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -151,26 +151,39 @@ package object dsl { * where a literal is being combined with a symbol. Without these an * expression such as 0 < 'x is not recognized. */ - implicit class InitialLiteral(x: Any) { + case class LhsLiteral(x: Any) { val literal = Literal(x) - def + (other: Symbol):Expression = {literal + other} - def - (other: Symbol):Expression = {literal - other} - def * (other: Symbol):Expression = {literal * other} - def / (other: Symbol):Expression = {literal / other} - def % (other: Symbol):Expression = {literal % other} - - def && (other: Symbol):Expression = {literal && other} - def || (other: Symbol):Expression = {literal || other} - - def < (other: Symbol):Expression = {literal < other} - def <= (other: Symbol):Expression = {literal <= other} - def > (other: Symbol):Expression = {literal > other} - def >= (other: Symbol):Expression = {literal >= other} - def === (other: Symbol):Expression = {literal === other} - def <=> (other: Symbol):Expression = {literal <=> other} - def !== (other: Symbol):Expression = {literal !== other} + def + (other: Symbol): Expression = literal + other + def - (other: Symbol): Expression = literal - other + def * (other: Symbol): Expression = literal * other + def / (other: Symbol): Expression = literal / other + def % (other: Symbol): Expression = literal % other + + def && (other: Symbol): Expression = literal && other + def || (other: Symbol): Expression = literal || other + + def < (other: Symbol): Expression = literal < other + def <= (other: Symbol): Expression = literal <= other + def > (other: Symbol): Expression = literal > other + def >= (other: Symbol): Expression = literal >= other + def === (other: Symbol): Expression = literal === other + def <=> (other: Symbol): Expression = literal <=> other + def !== (other: Symbol): Expression = literal !== other } + implicit def booleanToLhsLiteral(b: Boolean) = new LhsLiteral(b) + implicit def byteToLhsLiteral(b: Byte) = new LhsLiteral(b) + implicit def shortToLhsLiteral(s: Short) = new LhsLiteral(s) + implicit def intToLhsLiteral(i: Int) = new LhsLiteral(i) + implicit def longToLhsLiteral(l: Long) = new LhsLiteral(l) + implicit def floatToLhsLiteral(f: Float) = new LhsLiteral(f) + implicit def doubleToLhsLiteral(d: Double) = new LhsLiteral(d) + implicit def stringToLhsLiteral(s: String) = new LhsLiteral(s) + implicit def bigDecimalToLhsLiteral(d: BigDecimal) = new LhsLiteral(d) + implicit def decimalToLhsLiteral(d: Decimal) = new LhsLiteral(d) + implicit def dateToLhsLiteral(d: Date) = new LhsLiteral(d) + implicit def timestampToLhsLiteral(t: Timestamp) = new LhsLiteral(t) + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? 
implicit class DslString(val s: String) extends ImplicitOperators { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4391bcede66d1..45b5e6e2c289a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -140,12 +140,6 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -238,13 +232,6 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -331,13 +318,6 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -424,13 +404,6 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -517,13 +490,6 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 738a5f3a53f3c..5c0864c896628 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,31 +44,6 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) - - /** - * This method can be used to construct a [[Row]] from a [[Seq]] of Strings, - * converting each item to the type specified in a [[StructType]] schema. - * Only primitive types can be used. - */ - def fromStringsBySchema(strings: Seq[String], schema: StructType): Row = { - val values = for { - (field, str) <- schema.fields zip strings - item = field.dataType match { - case IntegerType => str.toInt - case LongType => str.toLong - case DoubleType => str.toDouble - case FloatType => str.toFloat - case ByteType => str.toByte - case ShortType => str.toShort - case StringType => str - case BooleanType => (str != "") - case DateType => Date.valueOf(str) - case TimestampType => Timestamp.valueOf(str) - case DecimalType() => new BigDecimal(str) - } - } yield item - new GenericRow(values.toArray) - } } /** @@ -91,8 +66,6 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String - def getDate(i: Int): Date - def getTimestamp(i: Int): Timestamp def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 34a10cf4a6945..0c4eb53ee81f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -343,19 +343,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR currentValue.value = value } - override def getDate(i: Int): Date = { - values(i).asInstanceOf[MutableDate].value - } override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] currentValue.isNull = false currentValue.value = value } - override def getTimestamp(i: Int): Timestamp = { - values(i).asInstanceOf[MutableTimestamp].value - } - override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala index 189412b312ccc..fcb77e640a8ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date => JDate, Timestamp => JTimestamp} +import java.sql.{Date, Timestamp} import scala.language.implicitConversions /* @@ -35,7 +35,7 @@ import scala.language.implicitConversions * res1: Boolean = true */ -class Date(milliseconds: Long) extends JDate(milliseconds) { +class RichDate(milliseconds: Long) extends Date(milliseconds) { def <(that: Date): Boolean = this.before(that) def >(that: Date): Boolean = this.after(that) def <=(that: Date): Boolean = (this.before(that) || this.equals(that)) @@ -43,8 +43,8 @@ class Date(milliseconds: Long) extends JDate(milliseconds) { def ===(that: Date): Boolean = this.equals(that) } -object Date { - def apply(init: String) = new Date(JDate.valueOf(init).getTime) +object RichDate { + def apply(init: String) = new RichDate(Date.valueOf(init).getTime) } /* @@ -60,7 +60,7 @@ object Date { * res13: Boolean = true */ 
-class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { +class RichTimestamp(milliseconds: Long) extends Timestamp(milliseconds) { def <(that: Timestamp): Boolean = this.before(that) def >(that: Timestamp): Boolean = this.after(that) def <=(that: Timestamp): Boolean = (this.before(that) || this.equals(that)) @@ -68,8 +68,8 @@ class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { def ===(that: Timestamp): Boolean = this.equals(that) } -object Timestamp { - def apply(init: String) = new Timestamp(JTimestamp.valueOf(init).getTime) +object RichTimestamp { + def apply(init: String) = new RichTimestamp(Timestamp.valueOf(init).getTime) } /* @@ -78,20 +78,20 @@ object Timestamp { object TimeConversions { - implicit def JDateToDate(jdate: JDate): Date = { + implicit def javaDateToRichDate(jdate: Date): RichDate = { new Date(jdate.getTime) } - implicit def JTimestampToTimestamp(jtimestamp: JTimestamp): Timestamp = { + implicit def javaTimestampToRichTimestamp(jtimestamp: Timestamp): RichTimestamp = { new Timestamp(jtimestamp.getTime) } - implicit def DateToJDate(date: Date): JDate = { - new JDate(date.getTime) + implicit def richDateToJavaDate(date: RichDate): Date = { + new Date(date.getTime) } - implicit def TimestampToJTimestamp(timestamp: Timestamp): JTimestamp = { - new JTimestamp(timestamp.getTime) + implicit def richTimestampToJavaTimestamp(timestamp: RichTimestamp): Timestamp = { + new Timestamp(timestamp.getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e105dbb8d47bd..42411b7dd840d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -459,12 +459,14 @@ package object sql { * use in DSL expressions. Implicit conversions to java.sql.Date * are provided. The class intializer accepts a String, e.g. * - * val ts = Date("2014-01-01") + * {{{ + * val d = Date("2014-01-01") + * }}} * * @group dataType */ @DeveloperApi - val Date = catalyst.expressions.Date + val Date = catalyst.expressions.RichDate /** * :: DeveloperApi :: @@ -474,10 +476,12 @@ package object sql { * java.sql.timestamp are provided. The class intializer accepts a * String, e.g. * + * {{{ * val ts = Timestamp("2014-01-01 12:34:56.78") + * }}} * * @group timeClasses */ @DeveloperApi - val Timestamp = catalyst.expressions.Timestamp + val Timestamp = catalyst.expressions.RichTimestamp } From 2aca97c7cfdefea8b6f9dbb88951e9acdfd606d9 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 3 Nov 2014 09:02:35 -0800 Subject: [PATCH 07/79] [EC2] Factor out Mesos spark-ec2 branch We reference a specific branch in two places. This patch makes it one place. 
Author: Nicholas Chammas Closes #3008 from nchammas/mesos-spark-ec2-branch and squashes the following commits: 10a6089 [Nicholas Chammas] factor out mess spark-ec2 branch --- ec2/spark_ec2.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0d6b82b4944f3..50f88f735650e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -41,8 +41,9 @@ DEFAULT_SPARK_VERSION = "1.1.0" +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -583,7 +584,13 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) From 3cca1962207745814b9d83e791713c91b659c36c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 12:24:24 -0800 Subject: [PATCH 08/79] [SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample The current way of seed distribution makes the random sequences from partition i and i+1 offset by 1. ~~~ In [14]: import random In [15]: r1 = random.Random(10) In [16]: r1.randint(0, 1) Out[16]: 1 In [17]: r1.random() Out[17]: 0.4288890546751146 In [18]: r1.random() Out[18]: 0.5780913011344704 In [19]: r2 = random.Random(10) In [20]: r2.randint(0, 1) Out[20]: 1 In [21]: r2.randint(0, 1) Out[21]: 0 In [22]: r2.random() Out[22]: 0.5780913011344704 ~~~ Note: The new tests are not for this bug fix. Author: Xiangrui Meng Closes #3010 from mengxr/SPARK-4148 and squashes the following commits: 869ae4b [Xiangrui Meng] move tests tests.py c1bacd9 [Xiangrui Meng] fix seed distribution and add some tests for rdd.sample --- python/pyspark/rdd.py | 3 --- python/pyspark/rddsampler.py | 11 +++++------ python/pyspark/tests.py | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 550c9dd80522f..4f025b9f11707 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -316,9 +316,6 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD (relies on numpy and falls back on default random generator if numpy is unavailable). 
- - >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 528a181e8905a..f5c3cfd259a5b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None): def initRandomGenerator(self, split): if self._use_numpy: import numpy - self._random = numpy.random.RandomState(self._seed) + self._random = numpy.random.RandomState(self._seed ^ split) else: - self._random = random.Random(self._seed) + self._random = random.Random(self._seed ^ split) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, 2 ** 32 - 1) + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 37a128907b3a7..253a471849c3a 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -648,6 +648,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): From 75690de234df13fe27bab28f7f1a9276ecdb168f Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Mon, 3 Nov 2014 14:32:10 -0600 Subject: [PATCH 09/79] Make implicit conversions for Literal op Symbol return a specific type, e.g. Add(1, 'x). 
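As a rough sketch of what this change enables (the import path is assumed; the node names come from catalyst's expression package, and the LhsLiteral conversions were added earlier in this series):

```scala
// Sketch only: assumes the DSL conversions, including the left-hand-side
// literal wrappers, are in scope via the usual catalyst DSL import.
import org.apache.spark.sql.catalyst.dsl.expressions._

val sum  = 1 + 'x   // now builds Add(Literal(1), UnresolvedAttribute("x"))
val pred = 0 < 'x   // now builds LessThan(Literal(0), UnresolvedAttribute("x"))
```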
--- .../spark/sql/catalyst/dsl/package.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c258235444720..45d74c32e5969 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -153,22 +153,22 @@ package object dsl { */ case class LhsLiteral(x: Any) { val literal = Literal(x) - def + (other: Symbol): Expression = literal + other - def - (other: Symbol): Expression = literal - other - def * (other: Symbol): Expression = literal * other - def / (other: Symbol): Expression = literal / other - def % (other: Symbol): Expression = literal % other - - def && (other: Symbol): Expression = literal && other - def || (other: Symbol): Expression = literal || other - - def < (other: Symbol): Expression = literal < other - def <= (other: Symbol): Expression = literal <= other - def > (other: Symbol): Expression = literal > other - def >= (other: Symbol): Expression = literal >= other - def === (other: Symbol): Expression = literal === other - def <=> (other: Symbol): Expression = literal <=> other - def !== (other: Symbol): Expression = literal !== other + def + (other: Symbol) = Add(literal, other) + def - (other: Symbol) = Subtract(literal, other) + def * (other: Symbol) = Multiply(literal, other) + def / (other: Symbol) = Divide(literal, other) + def % (other: Symbol) = Remainder(literal, other) + + def && (other: Symbol) = And(literal, other) + def || (other: Symbol) = Or(literal, other) + + def < (other: Symbol) = LessThan(literal, other) + def <= (other: Symbol) = LessThanOrEqual(literal, other) + def > (other: Symbol) = GreaterThan(literal, other) + def >= (other: Symbol) = GreaterThanOrEqual(literal, other) + def === (other: Symbol) = EqualTo(literal, other) + def <=> (other: Symbol) = EqualNullSafe(literal, other) + def !== (other: Symbol) = Not(EqualTo(literal, other)) } implicit def booleanToLhsLiteral(b: Boolean) = new LhsLiteral(b) From df607da025488d6c924d3d70eddb67f5523080d3 Mon Sep 17 00:00:00 2001 From: fi Date: Mon, 3 Nov 2014 12:56:56 -0800 Subject: [PATCH 10/79] [SPARK-4211][Build] Fixes hive.version in Maven profile hive-0.13.1 instead of `hive.version=0.13.1`. e.g. mvn -Phive -Phive=0.13.1 Note: `hive.version=0.13.1a` is the default property value. However, when explicitly specifying the `hive-0.13.1` maven profile, the wrong one would be selected. References: PR #2685, which resolved a package incompatibility issue with Hive-0.13.1 by introducing a special version Hive-0.13.1a Author: fi Closes #3072 from coderfi/master and squashes the following commits: 7ca4b1e [fi] Fixes the `hive-0.13.1` maven profile referencing `hive.version=0.13.1` instead of the Spark compatible `hive.version=0.13.1a` Note: `hive.version=0.13.1a` is the default version. However, when explicitly specifying the `hive-0.13.1` maven profile, the wrong one would be selected. e.g. 
mvn -Phive -Phive=0.13.1 See PR #2685 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 6191cd3a541e2..eb613531b8a5f 100644 --- a/pom.xml +++ b/pom.xml @@ -1359,7 +1359,7 @@ false - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 From 2b6e1ce6ee7b1ba8160bcbee97f5bbff5c46ca09 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 3 Nov 2014 13:07:41 -0800 Subject: [PATCH 11/79] [SPARK-4207][SQL] Query which has syntax like 'not like' is not working in Spark SQL Queries which has 'not like' is not working spark sql. sql("SELECT * FROM records where value not like 'val%'") same query works in Spark HiveQL Author: ravipesala Closes #3075 from ravipesala/SPARK-4207 and squashes the following commits: 35c11e7 [ravipesala] Supported 'not like' syntax in sql --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 1 + .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 00fc4d75c9ea9..5e613e0f18ba6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -242,6 +242,7 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6bf439377aa3e..702714af5308d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -938,4 +938,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), (11 to 100).map(i => Seq(i))) } + + test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + (1 to 99).map(i => Seq(i))) + } } From 24544fbce05665ab4999a1fe5aac434d29cd912c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 3 Nov 2014 13:17:09 -0800 Subject: [PATCH 12/79] [SPARK-3594] [PySpark] [SQL] take more rows to infer schema or sampling This patch will try to infer schema for RDD which has empty value (None, [], {}) in the first row. It will try first 100 rows and merge the types into schema, also merge fields of StructType together. If there is still NullType in schema, then it will show an warning, tell user to try with sampling. If sampling is presented, it will infer schema from all the rows after sampling. 
Also, add samplingRatio for jsonFile() and jsonRDD() Author: Davies Liu Author: Davies Liu Closes #2716 from davies/infer and squashes the following commits: e678f6d [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 34b5c63 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 567dc60 [Davies Liu] update docs 9767b27 [Davies Liu] Merge branch 'master' into infer e48d7fb [Davies Liu] fix tests 29e94d5 [Davies Liu] let NullType inherit from PrimitiveType ee5d524 [Davies Liu] Merge branch 'master' of github.com:apache/spark into infer 540d1d5 [Davies Liu] merge fields for StructType f93fd84 [Davies Liu] add more tests 3603e00 [Davies Liu] take more rows to infer schema, or infer the schema by sampling the RDD --- python/pyspark/sql.py | 196 ++++++++++++------ python/pyspark/tests.py | 19 ++ .../spark/sql/catalyst/types/dataTypes.scala | 2 +- 3 files changed, 148 insertions(+), 69 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 98e41f8575679..675df084bf303 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -109,6 +109,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -331,7 +340,7 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable, metadata=None): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. @@ -484,6 +493,7 @@ def _parse_datatype_json_value(json_value): # Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -500,22 +510,22 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj""" - if obj is None: - raise ValueError("Can not infer type for None") - dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -548,60 +558,93 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) 
+ raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = _create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") + raise ValueError("Unexpected obj: %s" % obj) - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv - - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) - - return nested_conv - - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -713,7 +756,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -1049,18 +1092,20 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def inferSchema(self, rdd): + def 
inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... [Row(field1=1, field2="row1"), @@ -1097,8 +1142,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1219,7 +1279,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1227,8 +1287,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1274,20 +1334,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. 
>>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1344,7 +1404,7 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 253a471849c3a..68fd756876219 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -796,6 +796,25 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cc5015ad3c013..e1b5992a36e5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -213,7 +213,7 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap /** Given the string representation of a type, return its DataType */ From c238fb423d1011bd1b1e6201d769b72e52664fc6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 3 Nov 2014 13:20:33 -0800 Subject: [PATCH 13/79] [SPARK-4202][SQL] Simple DSL support for Scala UDF This feature is based on an offline discussion with mengxr, hopefully can be useful for the new MLlib pipeline API. 
For the following test snippet ```scala case class KeyValue(key: Int, value: String) val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD def foo(a: Int, b: String) => a.toString + b ``` the newly introduced DSL enables the following syntax ```scala import org.apache.spark.sql.catalyst.dsl._ testData.select(Star(None), foo.call('key, 'value) as 'result) ``` which is equivalent to ```scala testData.registerTempTable("testData") sqlContext.registerFunction("foo", foo) sql("SELECT *, foo(key, value) AS result FROM testData") ``` Author: Cheng Lian Closes #3067 from liancheng/udf-dsl and squashes the following commits: f132818 [Cheng Lian] Adds DSL support for Scala UDF --- .../spark/sql/catalyst/dsl/package.scala | 59 +++++++++++++++++++ .../org/apache/spark/sql/DslQuerySuite.scala | 17 ++++-- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7e6d770314f5a..3314e15477016 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -285,4 +286,62 @@ package object dsl { def writeToFile(path: String) = WriteToFile(path, logicalPlan) } } + + case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { + def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } + + // scalastyle:off + /** functionToUdfBuilder 1-22 were generated by this script + + (1 to 22).map { x => + val argTypes = Seq.fill(x)("_").mkString(", ") + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + } + */ + + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: 
Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + // scalastyle:on } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a2..e70ad891eea36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.test.TestSQLContext._ class DslQuerySuite extends QueryTest { - import TestData._ + import org.apache.spark.sql.TestData._ test("table scan") { checkAnswer( @@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest { (4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + + test("udf") { + val foo = (a: Int, b: String) => a.toString + b + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select(Star(None), foo.call('key, 'value)).limit(3), + (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + ) + } } From e83f13e8d37ca33f4e183e977d077221b90c6025 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 3 Nov 2014 13:59:43 -0800 Subject: [PATCH 14/79] [SPARK-4152] [SQL] Avoid data change in CTAS while table already existed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CREATE TABLE t1 (a String); CREATE TABLE t1 AS SELECT key FROM src; – throw exception CREATE TABLE if not exists t1 AS SELECT key FROM src; – expect do nothing, currently it will overwrite the t1, which is incorrect. 
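The intended semantics, as a minimal sketch (assumes an existing HiveContext bound to `hiveContext` and a populated `src` table, as in the Hive test fixtures; names are illustrative):

```scala
// Sketch of the behavior after this patch.
hiveContext.sql("CREATE TABLE t1 (a String)")

// Plain CTAS over an existing table now fails fast instead of changing its data:
// throws org.apache.hadoop.hive.metastore.api.AlreadyExistsException
// hiveContext.sql("CREATE TABLE t1 AS SELECT key FROM src")

// CTAS with IF NOT EXISTS leaves the existing (empty) t1 untouched, matching Hive:
hiveContext.sql("CREATE TABLE IF NOT EXISTS t1 AS SELECT key FROM src")
```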
Author: Cheng Hao Closes #3013 from chenghao-intel/ctas_unittest and squashes the following commits: 194113e [Cheng Hao] fix bug in CTAS when table already existed --- .../spark/sql/catalyst/analysis/Catalog.scala | 22 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 +++++ .../hive/execution/CreateTableAsSelect.scala | 12 +++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 9 ++++++-- 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 2059a91ba0612..0415d74bd8141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,6 +28,8 @@ trait Catalog { def caseSensitive: Boolean + def tableExists(db: Option[String], tableName: String): Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -82,6 +84,14 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables.clear() } + override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + tables.get(tblName) match { + case Some(_) => true + case None => false + } + } + override def lookupRelation( databaseName: Option[String], tableName: String, @@ -107,6 +117,14 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + abstract override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + overrides.get((dbName, tblName)) match { + case Some(_) => true + case None => super.tableExists(db, tableName) + } + } + abstract override def lookupRelation( databaseName: Option[String], tableName: String, @@ -149,6 +167,10 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true + def tableExists(db: Option[String], tableName: String): Boolean = { + throw new UnsupportedOperationException + } + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 096b4a07aa2ea..0baf4c9f8c7ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -57,6 +57,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + def tableExists(db: Option[String], tableName: String): Boolean = { + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + client.getTable(databaseName, tblName, false) != null + } + def lookupRelation( db: Option[String], tableName: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 2fce414734579..3d24d87bc3d38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ 
-71,7 +71,17 @@ case class CreateTableAsSelect( // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + if (sc.catalog.tableExists(Some(database), tableName)) { + if (allowExisting) { + // table already exists, will do nothing, to keep consistent with Hive + } else { + throw + new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + } + } else { + sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + } + Seq.empty[Row] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 76a0ec01a6075..e9b1943ff8db7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -56,7 +56,7 @@ class SQLQuerySuite extends QueryTest { sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect - // expect the string => integer for field key cause the table ctas4 already existed. + // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect @@ -78,9 +78,14 @@ class SQLQuerySuite extends QueryTest { SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) + intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), - sql("SELECT CAST(key AS int) k, value FROM src ORDER BY k, value").collect().toSeq) + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) checkExistence(sql("DESC EXTENDED ctas2"), true, "name:key", "type:string", "name:value", "ctas2", From 25bef7e6951301e93004567fc0cef96bf8d1a224 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 3 Nov 2014 14:08:27 -0800 Subject: [PATCH 15/79] [SQL] More aggressive defaults - Turns on compression for in-memory cached data by default - Changes the default parquet compression format back to gzip (we have seen more OOMs with production workloads due to the way Snappy allocates memory) - Ups the batch size to 10,000 rows - Increases the broadcast threshold to 10mb. - Uses our parquet implementation instead of the hive one by default. - Cache parquet metadata by default. Author: Michael Armbrust Closes #3064 from marmbrus/fasterDefaults and squashes the following commits: 97ee9f8 [Michael Armbrust] parquet codec docs e641694 [Michael Armbrust] Remote also a12866a [Michael Armbrust] Cache metadata. 2d73acc [Michael Armbrust] Update docs defaults. 
d63d2d5 [Michael Armbrust] document parquet option da373f9 [Michael Armbrust] More aggressive defaults --- docs/sql-programming-guide.md | 18 +++++++++++++----- .../scala/org/apache/spark/sql/SQLConf.scala | 10 +++++----- .../sql/parquet/ParquetTableOperations.scala | 6 +++--- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d4ade939c3a6e..e399fecbbc78c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 07e6e2eccddf4..279495aa64755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -79,13 +79,13 @@ private[sql] trait SQLConf { private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt /** Number of partitions to use for shuffle operators. 
*/ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -106,10 +106,10 @@ private[sql] trait SQLConf { * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 9664c565a0b86..d00860a8bb8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -123,7 +123,7 @@ case class ParquetTableScan( // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( @@ -394,7 +394,7 @@ private[parquet] class FilteringParquetRowInputFormat if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap if (statuses.isEmpty) { @@ -493,7 +493,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter; import parquet.filter2.compat.RowGroupFilter; - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f025169ad5063..e88afaaf001c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -90,7 +90,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SerDe. */ private[spark] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } From 28128150e7e0c2b7d1c483e67214bdaef59f7d75 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 3 Nov 2014 15:19:01 -0800 Subject: [PATCH 16/79] SPARK-4178. Hadoop input metrics ignore bytes read in RecordReader insta... 
...ntiation Author: Sandy Ryza Closes #3045 from sryza/sandy-spark-4178 and squashes the following commits: 8d2e70e [Sandy Ryza] Kostas's review feedback e5b27c0 [Sandy Ryza] SPARK-4178. Hadoop input metrics ignore bytes read in RecordReader instantiation --- .../org/apache/spark/rdd/HadoopRDD.scala | 25 +++++++++-------- .../org/apache/spark/rdd/NewHadoopRDD.scala | 26 +++++++++--------- .../spark/metrics/InputMetricsSuite.scala | 27 +++++++++++++++++-- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 946fb5616d3ec..a157e36e2286e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -211,20 +211,11 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) @@ -234,6 +225,18 @@ class HadoopRDD[K, V]( if (bytesReadCallback.isDefined) { context.taskMetrics.inputMetrics = Some(inputMetrics) } + + var reader: RecordReader[K, V] = null + val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. 
+ context.addTaskCompletionListener{ context => closeIfNeeded() } + val key: K = reader.createKey() + val value: V = reader.createValue() + var recordsSinceMetricsUpdate = 0 override def getNext() = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6d6b86721ca74..351e145f96f9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -107,20 +107,10 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - // Find a function that will return the FileSystem bytes read by this thread. + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) @@ -131,6 +121,18 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + // Register an on-task-completion callback to close the input stream. 
context.addTaskCompletionListener(context => close()) var havePair = false diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala index 33bd1afea2470..48c386ba04311 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import java.io.{FileWriter, PrintWriter, File} class InputMetricsSuite extends FunSuite with SharedSparkContext { - test("input metrics when reading text file") { + test("input metrics when reading text file with single split") { val file = new File(getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(file)) pw.println("some stuff") @@ -48,6 +48,29 @@ class InputMetricsSuite extends FunSuite with SharedSparkContext { // Wait for task end events to come in sc.listenerBus.waitUntilEmpty(500) assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum == file.length()) + assert(taskBytesRead.sum >= file.length()) + } + + test("input metrics when reading text file with multiple splits") { + val file = new File(getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(file)) + for (i <- 0 until 10000) { + pw.println("some stuff") + } + pw.close() + file.deleteOnExit() + + val taskBytesRead = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + } + }) + sc.textFile("file://" + file.getAbsolutePath, 2).count() + + // Wait for task end events to come in + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesRead.length == 2) + assert(taskBytesRead.sum >= file.length()) } } From 15b58a2234ab7ba30c9c0cbb536177a3c725e350 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 3 Nov 2014 18:04:51 -0800 Subject: [PATCH 17/79] [SQL] Convert arguments to Scala UDFs Author: Michael Armbrust Closes #3077 from marmbrus/udfsWithUdts and squashes the following commits: 34b5f27 [Michael Armbrust] style 504adef [Michael Armbrust] Convert arguments to Scala UDFs --- .../sql/catalyst/expressions/ScalaUdf.scala | 560 ++++++++++-------- .../spark/sql/UserDefinedTypeSuite.scala | 18 +- 2 files changed, 316 insertions(+), 262 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fa1786e74bb3e..18c96da2f87fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -34,320 +34,366 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + }.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size 
match { case 0 => function.asInstanceOf[() => Any]() - case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + 
ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) + + case 9 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) + + case 10 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) + + case 11 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - 
children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) + + case 12 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) + + case 13 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + 
ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) + + case 14 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) + + case 15 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), 
children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) + + case 16 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) + + case 17 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), 
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) + + case 18 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) + + case 19 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + 
ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) + + case 20 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), 
children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) + + case 21 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) + + case 22 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input), - children(21).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), 
children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), + ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) + } // scalastyle:on diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 666235e57f812..1806a1dd82023 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -60,13 +60,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) - test("register user type: MyDenseVector for MyLabeledPoint") { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) @@ -80,4 +80,12 @@ class UserDefinedTypeSuite extends QueryTest { assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } } From 97a466eca0a629f17e9662ca2b59eeca99142c54 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Mon, 3 Nov 2014 18:17:32 -0800 Subject: [PATCH 18/79] [SPARK-4168][WebUI] web statges number 
should show correctly when stages are more than 1000 The number of completed stages and failed stages showed on webUI will always be less than 1000. This is really misleading when there are already thousands of stages completed or failed. The number should be correct even when only partial stages listed on the webUI (stage info will be removed if the number is too large). Author: Zhang, Liye Closes #3035 from liyezhang556520/webStageNum and squashes the following commits: d9e29fb [Zhang, Liye] add detailed comments for variables 4ea8fd1 [Zhang, Liye] change variable name accroding to comments f4c404d [Zhang, Liye] [SPARK-4168][WebUI] web statges number should show correctly when stages are more than 1000 --- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 9 +++++++++ .../org/apache/spark/ui/jobs/JobProgressPage.scala | 10 ++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b5207360510dd..e3223403c17f4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 6e718eecdd52a..83a7898071c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
              Completed Stages:
-             {completedStages.size}
+             {numCompletedStages}
              Failed Stages:
-             {failedStages.size}
+             {numFailedStages}
  • @@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") }} ++

          Active Stages ({activeStages.size}) ++
          activeStagesTable.toNodeSeq ++
-         Completed Stages ({completedStages.size}) ++
+         Completed Stages ({numCompletedStages}) ++
          completedStagesTable.toNodeSeq ++
-         Failed Stages ({failedStages.size}) ++
+         Failed Stages ({numFailedStages})
    ++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) From 4f035dd2cd6f1ec9059811f3495f3e0a8ec5fb84 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 3 Nov 2014 18:18:47 -0800 Subject: [PATCH 19/79] [SPARK-611] Display executor thread dumps in web UI This patch allows executor thread dumps to be collected on-demand and viewed in the Spark web UI. The thread dumps are collected using Thread.getAllStackTraces(). To allow remote thread dumps to be triggered from the web UI, I added a new `ExecutorActor` that runs inside of the Executor actor system and responds to RPCs from the driver. The driver's mechanism for obtaining a reference to this actor is a little bit hacky: it uses the block manager master actor to determine the host/port of the executor actor systems in order to construct ActorRefs to ExecutorActor. Unfortunately, I couldn't find a much cleaner way to do this without a big refactoring of the executor -> driver communication. Screenshots: ![image](https://cloud.githubusercontent.com/assets/50748/4781793/7e7a0776-5cbf-11e4-874d-a91cd04620bd.png) ![image](https://cloud.githubusercontent.com/assets/50748/4781794/8bce76aa-5cbf-11e4-8d13-8477748c9f7e.png) ![image](https://cloud.githubusercontent.com/assets/50748/4781797/bd11a8b8-5cbf-11e4-9ad7-a7459467ec8e.png) Author: Josh Rosen Closes #2944 from JoshRosen/jstack-in-web-ui and squashes the following commits: 3c21a5d [Josh Rosen] Address review comments: 880f7f7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui f719266 [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui 19707b0 [Josh Rosen] Add one comment. 127a130 [Josh Rosen] Update to use SparkContext.DRIVER_IDENTIFIER b8e69aa [Josh Rosen] Merge remote-tracking branch 'origin/master' into jstack-in-web-ui 3dfc2d4 [Josh Rosen] Add missing file. bc1e675 [Josh Rosen] Undo some leftover changes from the earlier approach. f4ac1c1 [Josh Rosen] Switch to on-demand collection of thread dumps dfec08b [Josh Rosen] Add option to disable thread dumps in UI. 4c87d7f [Josh Rosen] Use separate RPC for sending thread dumps. 2b8bdf3 [Josh Rosen] Enable thread dumps from the driver when running in non-local mode. cc3e6b3 [Josh Rosen] Fix test code in DAGSchedulerSuite. 87b8b65 [Josh Rosen] Add new listener event for thread dumps. 8c10216 [Josh Rosen] Add missing file. 
0f198ac [Josh Rosen] [SPARK-611] Display executor thread dumps in web UI --- .../scala/org/apache/spark/SparkContext.scala | 29 +++++++- .../CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/executor/Executor.scala | 7 +- .../apache/spark/executor/ExecutorActor.scala | 41 +++++++++++ .../spark/storage/BlockManagerMaster.scala | 4 + .../storage/BlockManagerMasterActor.scala | 18 +++++ .../spark/storage/BlockManagerMessages.scala | 2 + .../ui/exec/ExecutorThreadDumpPage.scala | 73 +++++++++++++++++++ .../apache/spark/ui/exec/ExecutorsPage.scala | 15 +++- .../apache/spark/ui/exec/ExecutorsTab.scala | 8 +- .../org/apache/spark/util/AkkaUtils.scala | 14 ++++ .../apache/spark/util/ThreadStackTrace.scala | 27 +++++++ .../scala/org/apache/spark/util/Utils.scala | 13 ++++ 13 files changed, 247 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala create mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b4db783979ec..40444c237b738 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,9 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable @@ -41,6 +40,7 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.executor.TriggerThreadDump import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.JobProgressListener -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -361,6 +361,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. 
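+   * Illustrative usage only; the executor ID "1" and the printed fields below are
+   * examples, not part of this method's contract:
+   * {{{
+   *   sc.getExecutorThreadDump("1").foreach { traces =>
+   *     traces.foreach(t => println(s"${t.threadName} (${t.threadState})"))
+   *   }
+   * }}}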
+ */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d41..3711824a40cfc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e24a15f015e1c..8b095e23f32ff 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -92,6 +92,10 @@ private[spark] class Executor( } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -131,6 +135,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a2553979..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. 
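+   * Returns None if the executor is not registered with the master or if the slave
+   * actor's Akka address does not carry a host and port (for example, a purely
+   * local address).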
+ */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113ac..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..e9c755e36f716 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => + + } + +
    +

    Updated at {UIUtils.formatDate(time)}

    + { + // scalastyle:off +

    + Expand All +

    +

    + // scalastyle:on + } +
    {dumpRows}
    +
    + }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..048fee3ce1ff4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) Thread Dump else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { {Utils.bytesToString(info.totalShuffleWrite)} + { + if (threadDumpEnabled) { + + Thread Dump + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 9e0e71a51a408..ba97630f025c1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 79e398eb8c104..10010bdfa1a51 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def makeExecutorRef( + name: String, + conf: SparkConf, + host: String, + port: Int, + actorSystem: ActorSystem): ActorRef = { + val executorActorSystemName = SparkEnv.executorActorSystemName + Utils.checkHost(host, "Expected hostname") + val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala new file mode 100644 index 0000000000000..d4e0ad93b966a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Used for shipping per-thread stacktraces from the executors to driver. + */ +private[spark] case class ThreadStackTrace( + threadId: Long, + threadName: String, + threadState: Thread.State, + stackTrace: String) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a33046d2040d8..6ab94af9f3739 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.util.jar.Attributes.Name @@ -1611,6 +1612,18 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ + def getThreadDump(): Array[ThreadStackTrace] = { + // We need to filter out null values here because dumpAllThreads() may return null array + // elements for threads that are dead / don't exist. + val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) + threadInfos.sortBy(_.getThreadId).map { case threadInfo => + val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, + threadInfo.getThreadState, stackTrace) + } + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ From c5912ecc7b392a13089ae735c07c2d7256de36c6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 18:50:37 -0800 Subject: [PATCH 20/79] [FIX][MLLIB] fix seed in BaggedPointSuite Saw Jenkins test failures due to random seeds. 
jkbradley manishamde Author: Xiangrui Meng Closes #3084 from mengxr/fix-baggedpoint-suite and squashes the following commits: f735a43 [Xiangrui Meng] fix seed in BaggedPointSuite --- .../spark/mllib/tree/impl/BaggedPointSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index c0a62e00432a3..5cb433232e714 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) baggedRDD.collect().foreach { baggedPoint => assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) } @@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) From 04450d11548cfb25d4fb77d4a33e3a7cd4254183 Mon Sep 17 
00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 19:29:11 -0800 Subject: [PATCH 21/79] [SPARK-4192][SQL] Internal API for Python UDT Following #2919, this PR adds Python UDT (for internal use only) with tests under "pyspark.tests". Before `SQLContext.applySchema`, we check whether we need to convert user-type instances into SQL recognizable data. In the current implementation, a Python UDT must be paired with a Scala UDT for serialization on the JVM side. A following PR will add VectorUDT in MLlib for both Scala and Python. marmbrus jkbradley davies Author: Xiangrui Meng Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits: acff637 [Xiangrui Meng] merge master dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well 2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion 7c4a6a9 [Xiangrui Meng] address comments 75223db [Xiangrui Meng] minor update f740379 [Xiangrui Meng] remove UDT from default imports e98d9d0 [Xiangrui Meng] fix py style 4e84fce [Xiangrui Meng] remove local hive tests and add more tests 39f19e0 [Xiangrui Meng] add tests b7f666d [Xiangrui Meng] add Python UDT --- python/pyspark/sql.py | 206 +++++++++++++++++- python/pyspark/tests.py | 93 +++++++- .../spark/sql/catalyst/types/dataTypes.scala | 9 +- .../org/apache/spark/sql/SQLContext.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 5 + .../spark/sql/test/ExamplePointUDT.scala | 64 ++++++ .../sql/types/util/DataTypeConversions.scala | 1 - 7 files changed, 375 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 675df084bf303..d16c18bc79fe4 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -417,6 +417,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. 
+ """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value): else: raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType @@ -509,7 +590,18 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + raise ValueError("Can not infer type for None") + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() @@ -558,6 +650,93 @@ def _infer_schema(row): return StructType(fields) +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -818,11 +997,22 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -897,6 +1087,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -967,6 +1159,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -1244,6 +1439,10 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) @@ -1877,6 +2076,7 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -1888,6 +2088,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 68fd756876219..e947b09468108 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -694,8 +695,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. 
+ """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -824,6 +882,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index e1b5992a36e5f..5dd19dd12d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -71,6 +71,8 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), ("type", JString("udt"))) => Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } @@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType + /** Paired Python UDT class, if exists. 
*/ + def pyUDT: String = null + /** * Convert the user type to a SQL datum * @@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def jsonValue: JValue = { ("type" -> "udt") ~ - ("class" -> this.getClass.getName) + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9e61d18f7e926..84eaf401f240c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation @@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => needsConversion(udt.sqlType) case other => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 997669051ed07..a83cf5d441d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -135,6 +135,8 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal // Pyrolite can handle Timestamp @@ -177,6 +179,9 @@ object EvaluatePython { case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + case (c: Int, ByteType) => c.toByte case (c: Long, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala new file mode 100644 index 0000000000000..b9569e96c0312 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types._ + +/** + * An example class to demonstrate UDT in Scala, Java, and Python. + * @param x x coordinate + * @param y y coordinate + */ +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +private[sql] class ExamplePoint(val x: Double, val y: Double) + +/** + * User-defined type for [[ExamplePoint]]. + */ +private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case p: ExamplePoint => + Seq(p.x, p.y) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: Seq[_] => + val xy = values.asInstanceOf[Seq[Double]] + assert(xy.length == 2) + new ExamplePoint(xy(0), xy(1)) + case values: util.ArrayList[_] => + val xy = values.asInstanceOf[util.ArrayList[Double]].asScala + new ExamplePoint(xy(0), xy(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 1bc15146f0fe8..3fa4a7c6481d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.UserDefinedType - protected[sql] object DataTypeConversions { /** From 39b8ad1c7eafedc65beae9ab8460efdfb672c4cd Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Tue, 4 Nov 2014 00:14:13 -0600 Subject: [PATCH 22/79] Reversed random line permutations. Eliminated all getters and setters for Date and Timestamp. Added Date and Timestamp to NativeType.defaultSizeOf. 
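For context on the last point: NativeType.defaultSizeOf appears to be a plain map from native types to a default width in bytes (see the dataTypes.scala hunk below), so Date and Timestamp columns now get size estimates like every other native type. A minimal illustrative lookup, assuming that map type; not code from the patch:

    // Widths introduced by this patch; the lookup itself is only an illustration.
    val dateWidth = NativeType.defaultSizeOf(DateType)           // 8
    val timestampWidth = NativeType.defaultSizeOf(TimestampType) // 12
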
--- .../spark/sql/catalyst/dsl/package.scala | 8 ++++---- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../sql/catalyst/expressions/Projection.scala | 2 -- .../spark/sql/catalyst/expressions/Row.scala | 19 +------------------ .../expressions/SpecificMutableRow.scala | 12 ------------ .../sql/catalyst/expressions/literals.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 4 +++- .../spark/sql/columnar/ColumnType.scala | 5 ----- .../scala/org/apache/spark/sql/package.scala | 8 ++++---- 9 files changed, 15 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 45d74c32e5969..bc4fbac7af1e1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -125,9 +125,9 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) implicit def decimalToLiteral(d: Decimal) = Literal(d) - implicit def dateToLiteral(d: Date) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -220,6 +220,9 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() @@ -227,9 +230,6 @@ package object dsl { def decimal(precision: Int, scale: Int) = AttributeReference(s, DecimalType(precision, scale), nullable = true)() - /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() - /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 38172eb1a50fb..22009666196a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -31,8 +31,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true - case (StringType, DateType) => true case (StringType, TimestampType) => true + case (StringType, DateType) => true case (_: NumericType, DateType) => true case (BooleanType, DateType) => true case (DateType, _: NumericType) => true @@ -333,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case decimal: DecimalType => castToDecimal(decimal) case DateType => castToDate + case decimal: DecimalType => castToDecimal(decimal) case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => 
castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 45b5e6e2c289a..80a54eb74a352 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date, Timestamp} - /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 5c0864c896628..99b9e6efbab90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.NativeType import java.sql.{Date, Timestamp} import java.math.BigDecimal @@ -101,8 +101,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - def setDate(ordinal: Int, value: Date) - def setTimestamp(ordinal: Int, value: Timestamp) } /** @@ -123,9 +121,6 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException - def getDate(i: Int): Date = throw new UnsupportedOperationException - def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException - override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this @@ -190,16 +185,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } - def getDate(i: Int): Date = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") - values(i).asInstanceOf[Date] - } - - def getTimestamp(i: Int): Timestamp = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") - values(i).asInstanceOf[Timestamp] - } - // Custom hashCode function that matches the efficient code generated version. 
override def hashCode(): Int = { var result: Int = 37 @@ -243,8 +228,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - override def setDate(ordinal: Int,value: Date): Unit = { values(ordinal) = value } - override def setTimestamp(ordinal: Int,value: Timestamp): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 0c4eb53ee81f2..9f977bf6c2a0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -337,18 +337,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableByte].value } - override def setDate(ordinal: Int, value: Date): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableDate] - currentValue.isNull = false - currentValue.value = value - } - - override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] - currentValue.isNull = false - currentValue.value = value - } - override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 548a9185998c3..93c19325151bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -34,8 +34,8 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case d: Date => Literal(d, DateType) case t: Timestamp => Literal(t, TimestampType) + case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 11e049061e6b7..7782cb05b8a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -92,9 +92,9 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType | "DecimalType()" ^^^ DecimalType.Unlimited | fixedDecimalType - | "DateType" ^^^ DateType | "TimestampType" ^^^ TimestampType ) @@ -200,6 +200,8 @@ object NativeType { FloatType -> 4, ShortType -> 2, ByteType -> 1, + DateType -> 8, + TimestampType -> 12, StringType -> 4096) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 475b65c2798c3..ab66c85c4f242 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -372,11 +372,6 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } - - def append(v: Date, buffer: ByteBuffer) { - buffer.putLong(v.getTime) - } - } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 42411b7dd840d..ea2f6ef103d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -161,22 +161,22 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `java.sql.Date` values. + * The data type representing `java.sql.Timestamp` values. * * @group dataType */ @DeveloperApi - val DateType = catalyst.types.DateType + val TimestampType = catalyst.types.TimestampType /** * :: DeveloperApi :: * - * The data type representing `java.sql.Timestamp` values. + * The data type representing `java.sql.Date` values. * * @group dataType */ @DeveloperApi - val TimestampType = catalyst.types.TimestampType + val DateType = catalyst.types.DateType /** * :: DeveloperApi :: From 3dd0da90d5d94506d27fa1ea7c07c0fd883243bc Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Tue, 4 Nov 2014 00:25:20 -0600 Subject: [PATCH 23/79] A couple more pointless changes undone. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/Projection.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8e8ba2dda4f37..8fbdf664b71e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -124,9 +124,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType + case obj: DateType.JvmType => DateType case obj: BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited - case obj: DateType.JvmType => DateType case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 80a54eb74a352..e7e81a21fdf03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions + /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. 
* @param expressions a sequence of expressions that determine the value of each column of the From 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 22:29:48 -0800 Subject: [PATCH 24/79] [SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map a RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley. ~~This PR contains the changes from #3068 . I will rebase after #3068 is merged.~~ marmbrus jkbradley Author: Xiangrui Meng Closes #3070 from mengxr/SPARK-3573 and squashes the following commits: 3a0b6e5 [Xiangrui Meng] organize imports 236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples --- dev/run-tests | 2 +- .../src/main/python/mllib/dataset_example.py | 62 +++++++++ .../spark/examples/mllib/DatasetExample.scala | 121 ++++++++++++++++++ mllib/pom.xml | 5 + .../apache/spark/mllib/linalg/Vectors.scala | 69 +++++++++- .../spark/mllib/linalg/VectorsSuite.scala | 11 ++ python/pyspark/mllib/linalg.py | 50 ++++++++ python/pyspark/mllib/tests.py | 39 +++++- 8 files changed, 353 insertions(+), 6 deletions(-) create mode 100644 examples/src/main/python/mllib/dataset_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala diff --git a/dev/run-tests b/dev/run-tests index 0e9eefa76a18b..de607e4344453 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. 
Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/mllib/pom.xml b/mllib/pom.xml index fb7239e779aae..87a7ddaba97f2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..ac217edc619ab 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. 
* @param values value array, must have the same length as the index array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..93a84fe07b32a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index d0a0e102a1a07..c0c3dff31e7f8 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,9 @@ import numpy as np +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -106,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4a..9fa4d6f6a2f5f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from 
pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): From 9bdc8412a0160e06e8182bd8b2f9bb65b478c590 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 3 Nov 2014 22:40:43 -0800 Subject: [PATCH 25/79] [SPARK-4163][Core] Add a backward compatibility test for FetchFailed /cc aarondav Author: zsxwing Closes #3086 from zsxwing/SPARK-4163-back-comp and squashes the following commits: 21cb2a8 [zsxwing] Add a backward compatibility test for FetchFailed --- .../org/apache/spark/util/JsonProtocolSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a91c9ddeaef36..01030120ae548 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -177,6 +177,17 @@ class JsonProtocolSuite extends FunSuite { deserializedBmRemoved) } + test("FetchFailed backwards compatibility") { + // FetchFailed in Spark 1.1.0 does not have an "Message" property. + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "ignored") + val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) + .removeField({ _._1 == "Message" }) + val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Unknown reason") + assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. 
val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") From b671ce047d036b8923007902826038b01e836e8a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 3 Nov 2014 22:47:45 -0800 Subject: [PATCH 26/79] [SPARK-4166][Core] Add a backward compatibility test for ExecutorLostFailure Author: zsxwing Closes #3085 from zsxwing/SPARK-4166-back-comp and squashes the following commits: 89329f4 [zsxwing] Add a backward compatibility test for ExecutorLostFailure --- .../scala/org/apache/spark/util/JsonProtocolSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 01030120ae548..aec1e409db95c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -196,6 +196,15 @@ class JsonProtocolSuite extends FunSuite { assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } + test("ExecutorLostFailure backward compatibility") { + // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. + val executorLostFailure = ExecutorLostFailure("100") + val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) + .removeField({ _._1 == "Executor ID" }) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ From e4f42631a68b473ce706429915f3f08042af2119 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 3 Nov 2014 23:56:14 -0800 Subject: [PATCH 27/79] [SPARK-3886] [PySpark] simplify serializer, use AutoBatchedSerializer by default. This PR simplify serializer, always use batched serializer (AutoBatchedSerializer as default), even batch size is 1. Author: Davies Liu This patch had conflicts when merged, resolved by Committer: Josh Rosen Closes #2920 from davies/fix_autobatch and squashes the following commits: e544ef9 [Davies Liu] revert unrelated change 6880b14 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 1d557fc [Davies Liu] fix tests 8180907 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 76abdce [Davies Liu] clean up 53fa60b [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch d7ac751 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch 2cc2497 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch b4292ce [Davies Liu] fix bug in master d79744c [Davies Liu] recover hive tests be37ece [Davies Liu] refactor eb3938d [Davies Liu] refactor serializer in scala 8d77ef2 [Davies Liu] simplify serializer, use AutoBatchedSerializer by default. 
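The behavioral core of the change is that the default batching becomes adaptive: serialization starts with a batch of one record and adjusts so each pickled payload stays roughly between 1 MiB and 10 MiB. A simplified sketch of that policy, with the thresholds taken from the AutoBatchedPickler code visible in the hunks below and the helper itself purely illustrative:

    // Grow the batch while payloads are small, shrink it when they get huge,
    // otherwise leave it alone. Mirrors the 1 MiB / 10 MiB bounds in the diff.
    def nextBatchSize(currentBatch: Int, lastPayloadBytes: Int): Int = {
      if (lastPayloadBytes < 1024 * 1024) currentBatch * 2
      else if (lastPayloadBytes > 1024 * 1024 * 10 && currentBatch > 1) currentBatch / 2
      else currentBatch
    }
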
--- .../spark/api/python/PythonHadoopUtil.scala | 6 +- .../apache/spark/api/python/PythonRDD.scala | 110 +--------------- .../apache/spark/api/python/SerDeUtil.scala | 121 +++++++++++++----- .../WriteInputFormatTestDataGenerator.scala | 10 +- .../mllib/api/python/PythonMLLibAPI.scala | 2 +- python/pyspark/context.py | 58 +++------ python/pyspark/mllib/common.py | 2 +- python/pyspark/mllib/recommendation.py | 2 +- python/pyspark/rdd.py | 91 ++++++------- python/pyspark/serializers.py | 36 ++---- python/pyspark/shuffle.py | 7 +- python/pyspark/sql.py | 18 +-- python/pyspark/tests.py | 66 ++-------- .../org/apache/spark/sql/SchemaRDD.scala | 10 +- 14 files changed, 201 insertions(+), 338 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349eac..5ba66178e2b78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61b125ef7c6c1..e94ccdcd47bb7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -22,12 +22,10 @@ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 -import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -442,7 +440,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -468,7 +466,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -494,7 +492,7 @@ 
private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -537,7 +535,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -563,7 +561,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -746,104 +744,6 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. - */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. 
- */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e0992..a4153aaa926f8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. 
+ */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. */ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - initialize() - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index e9ca9166eb4d6..c0cbd28a845be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", 
TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index acdc67ddc660a..65b98a8ceea55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -736,7 +736,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5f8dcedb1eea2..a0e4821728c8b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -63,7 +63,6 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, @@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer - elif batchSize == 0: + if batchSize == 0: self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, @@ -305,12 +302,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +321,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -405,7 +397,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file 
system URI. @@ -427,17 +419,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -458,18 +448,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -487,18 +475,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. 
@@ -519,18 +505,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -548,15 +532,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -836,7 +818,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 76864d8163586..dbe5f698b7345 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -96,7 +96,7 @@ def _java2py(sc, r): if clsName == 'JavaRDD': jrdd = sc._jvm.SerDe.javaToPython(r) - return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer())) + return RDD(jrdd, sc) elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6b32af07c9be2..e8b998414d319 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -117,7 +117,7 @@ def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 
4f025b9f11707..879655dc53f4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -446,12 +442,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -1120,9 +1115,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1147,9 +1141,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1166,9 +1159,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1195,9 +1187,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, 
BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1215,9 +1206,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): :param path: path to sequence file :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1232,8 +1222,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1774,13 +1767,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1790,12 +1780,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1934,25 +1928,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
""" - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2132,7 +2115,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 904bd9f2652d3..d597cbf94e1b1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,7 +62,7 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): @@ -113,7 +102,7 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ def __hash__(self): return hash(str(self)) @@ -181,6 +170,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -213,10 +203,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) def __repr__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -225,7 +215,7 @@ class AutoBatchedSerializer(BatchedSerializer): """ def __init__(self, serializer, bestSize=1 << 16): - BatchedSerializer.__init__(self, serializer, -1) + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -248,10 +238,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def 
__str__(self): - return "AutoBatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -284,7 +274,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "CartesianDeserializer<%s, %s>" % \ + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -311,7 +301,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -430,7 +420,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index d57a802e4734a..5931e923c2e36 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import random import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer try: import psutil @@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -470,7 +469,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) def _get_path(self, n): """ Choose one directory for spill by number n """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d16c18bc79fe4..e5d62a466cab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -44,7 +44,8 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -1233,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -1263,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if 
len(pickled_command) > (1 << 20): # 1M @@ -1443,8 +1443,7 @@ def applySchema(self, rdd, schema): converter = _python_to_sql_converter(schema) rdd = rdd.map(converter) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1841,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -2071,16 +2070,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e947b09468108..7e61b017efa75 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -242,7 +242,7 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() @@ -253,7 +253,7 @@ class ReusedPySparkTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + cls.sc = SparkContext('local[4]', cls.__name__) @classmethod def tearDownClass(cls): @@ -671,7 +671,7 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -1012,16 +1012,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - 
batchSize=1).collect()) - self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1341,51 +1344,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3ee2ea05cfa2d..fbec2f9f4b2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{List => JList} +import org.apache.spark.api.python.SerDeUtil + import scala.collection.JavaConversions._ import net.razorvine.pickle.Pickler @@ -385,12 +387,8 @@ class SchemaRDD( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val fieldTypes = schema.fields.map(_.dataType) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /** From bcecd73fdd4d2ec209259cfd57d3ad1d63f028f2 Mon Sep 17 00:00:00 2001 From: Dariusz Kobylarz Date: Tue, 4 Nov 2014 09:53:43 -0800 Subject: [PATCH 28/79] fixed MLlib Naive-Bayes java example bug the filter tests Double objects by references whereas it should test their values Author: Dariusz Kobylarz Closes #3081 from dkobylarz/master and squashes the following commits: 5d43a39 [Dariusz Kobylarz] naive bayes example update a304b93 [Dariusz Kobylarz] fixed MLlib Naive-Bayes 
java example bug --- docs/mllib-naive-bayes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c6563944..d5b044d94fdd7 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} From f90ad5d426cb726079c490a9bb4b1100e2b4e602 Mon Sep 17 00:00:00 2001 From: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Date: Tue, 4 Nov 2014 09:57:03 -0800 Subject: [PATCH 29/79] [Spark-4060] [MLlib] exposing special rdd functions to the public Author: Niklas Wilcke <1wilcke@informatik.uni-hamburg.de> Closes #2907 from numbnut/master and squashes the following commits: 7f7c767 [Niklas Wilcke] [Spark-4060] [MLlib] exposing special rdd functions to the public, #2907 --- .../spark/mllib/evaluation/AreaUnderCurve.scala | 2 +- .../org/apache/spark/mllib/rdd/RDDFunctions.scala | 11 ++++++----- .../scala/org/apache/spark/mllib/rdd/SlidingRDD.scala | 5 +++-- .../apache/spark/mllib/rdd/RDDFunctionsSuite.scala | 6 +++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483f..078fbfbe4f0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14d..57c0768084e41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. */ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. 
*/ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f001..35e81fcb3de0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242b..4ef67a40b9f49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { From 5e73138a0152b78380b3f1def4b969b58e70dd11 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 4 Nov 2014 16:15:38 -0800 Subject: [PATCH 30/79] [SPARK-2938] Support SASL authentication in NettyBlockTransferService Also lays the groundwork for supporting it inside the external shuffle service. Author: Aaron Davidson Closes #3087 from aarondav/sasl and squashes the following commits: 3481718 [Aaron Davidson] Delete rogue println 44f8410 [Aaron Davidson] Delete documentation - muahaha! 
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level a6b95f1 [Aaron Davidson] Address comments 785bbde [Aaron Davidson] Cleanup 79973cb [Aaron Davidson] Remove unused file 151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination 7b42adb [Aaron Davidson] Add unit tests 8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService --- .../org/apache/spark/SecurityManager.scala | 23 ++- .../scala/org/apache/spark/SparkConf.scala | 6 + .../scala/org/apache/spark/SparkContext.scala | 2 + .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../org/apache/spark/SparkSaslClient.scala | 147 --------------- .../org/apache/spark/SparkSaslServer.scala | 176 ------------------ .../org/apache/spark/executor/Executor.scala | 1 + .../netty/NettyBlockTransferService.scala | 28 ++- .../apache/spark/network/nio/Connection.scala | 5 +- .../spark/network/nio/ConnectionManager.scala | 7 +- .../apache/spark/storage/BlockManager.scala | 45 +++-- .../NettyBlockTransferSecuritySuite.scala | 161 ++++++++++++++++ .../network/nio/ConnectionManagerSuite.scala | 6 +- .../BlockManagerReplicationSuite.scala | 2 + .../spark/storage/BlockManagerSuite.scala | 4 +- docs/security.md | 1 - .../spark/network/TransportContext.java | 15 +- .../spark/network/client/TransportClient.java | 11 +- .../client/TransportClientBootstrap.java | 32 ++++ .../client/TransportClientFactory.java | 64 +++++-- .../spark/network/server/NoOpRpcHandler.java | 2 +- .../spark/network/server/RpcHandler.java | 19 +- .../server/TransportRequestHandler.java | 1 + .../spark/network/util/TransportConf.java | 3 + .../network/sasl/SaslClientBootstrap.java | 74 ++++++++ .../spark/network/sasl/SaslMessage.java | 74 ++++++++ .../spark/network/sasl/SaslRpcHandler.java | 97 ++++++++++ .../spark/network/sasl/SecretKeyHolder.java | 35 ++++ .../spark/network/sasl/SparkSaslClient.java | 138 ++++++++++++++ .../spark/network/sasl/SparkSaslServer.java | 170 +++++++++++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../shuffle/ExternalShuffleClient.java | 15 +- .../spark/network/shuffle/ShuffleClient.java | 11 +- .../network/sasl/SaslIntegrationSuite.java | 172 +++++++++++++++++ .../spark/network/sasl/SparkSaslSuite.java | 89 +++++++++ .../ExternalShuffleIntegrationSuite.java | 7 +- .../streaming/ReceivedBlockHandlerSuite.scala | 1 + 37 files changed, 1257 insertions(+), 392 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/SparkSaslServer.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala create mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java create mode 100644 
network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..dee935ffad51f 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. + * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. 
*/ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + override def getSaslUser(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSaslUser() + } + + override def getSecretKey(appId: String): String = { + val myAppId = sparkConf.getAppId + require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") + getSecretKey() + } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ad0a9017afead..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 40444c237b738..3cdaa6a9cc8a8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e2f13accdfab5..45e9d7f243e96 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -276,7 +276,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf) + new NettyBlockTransferService(conf, securityManager) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -285,6 +285,7 @@ object SparkEnv extends Logging { "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. 
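Since SecurityManager now implements SecretKeyHolder, the SASL layer looks up the user and secret by appId, and the overrides above reject any appId that does not match spark.app.id. A minimal sketch of how those keyed lookups behave, assuming a non-YARN deployment where the secret is taken from spark.authenticate.secret:

    import org.apache.spark.{SecurityManager, SparkConf}

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val securityManager = new SecurityManager(conf)
    securityManager.getSecretKey("app-id")     // returns the shared secret ("good") for the matching appId
    // securityManager.getSecretKey("other-id") fails the require() with a "did not match my appId" message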
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index a954fcc0c31fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. 
- * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index 7c2afb364661f..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap - -import com.google.common.base.Charsets.UTF_8 -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. 
- * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), UTF_8) - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. - */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), UTF_8).toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8b095e23f32ff..abc1dd0be6237 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -86,6 +86,7 @@ private[spark] class Executor( conf, executorId, slaveHostname, port, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 1c4327cf13b51..0d1fc81d2a16f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,13 +17,15 @@ package org.apache.spark.network.netty +import scala.collection.JavaConversions._ import scala.concurrent.{Future, Promise} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer @@ -33,18 +35,30 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService { + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
- val serializer = new JavaSerializer(conf) + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) - clientFactory = transportContext.createClientFactory() + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) server = transportContext.createServer() logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 4f6f5e235811d..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,12 +23,13 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} + private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8408b75bb4d65..f198aa8564a54 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils import scala.util.Try @@ -600,7 +601,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -634,7 +635,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken 
= connection.sparkSaslServer.response(securityMsg.getToken) @@ -778,7 +779,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5f5dd0dc1c63f..655d16c65c8b5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -57,6 +57,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -69,8 +75,6 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -102,22 +106,16 @@ private[spark] class BlockManager( + " switch to sort-based shuffle.") } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external // service, or just our own Executor's BlockManager. - private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { - BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) - } else { - blockManagerId - } + private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val appId = conf.get("spark.app.id", "unknown-app-id") - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) } else { blockTransferService } @@ -150,8 +148,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -176,10 +172,27 @@ private[spark] class BlockManager( } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. + * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. */ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) // Register Executors' configuration with the local shuffle service, if one should exist. diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala new file mode 100644 index 0000000000000..bed0ed9d713dd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
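Because registration now happens in initialize(appId) rather than in the constructor, a BlockManager created outside SparkContext or Executor must be initialized explicitly before use. A sketch of the construct-then-initialize pattern the test suites below adopt (the constructor arguments are the ones already set up in BlockManagerReplicationSuite):

    val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
      mapOutputTracker, shuffleManager, transfer)
    store.initialize("app-id")   // brings up the transfer service and registers with the master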
+ */ + +package org.apache.spark.network.netty + +import java.nio._ +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.util.{Failure, Success, Try} + +import org.apache.commons.io.IOUtils +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.{SecurityManager, SparkConf} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} + +class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { + test("security default off") { + testConnection(new SparkConf, new SparkConf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on same password") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + test("security on mismatch password") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate.secret", "bad") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Mismatched response") + } + } + + test("security mismatch auth off on server") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "false") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC + } + } + + test("security mismatch auth off on client") { + val conf0 = new SparkConf() + .set("spark.authenticate", "false") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.authenticate", "true") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("Expected SaslMessage") + } + } + + test("security mismatch app ids") { + val conf0 = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + val conf1 = conf0.clone.set("spark.app.id", "other-id") + testConnection(conf0, conf1) match { + case Success(_) => fail("Should have failed") + case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") + } + } + + /** + * Creates two servers with different configurations and sees if they can talk. + * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed + * properly. We will throw an out-of-band exception if something other than that goes wrong. + */ + private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = { + val blockManager = mock[BlockDataManager] + val blockId = ShuffleBlockId(0, 1, 2) + val blockString = "Hello, world!" 
+ val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + + val securityManager0 = new SecurityManager(conf0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0) + exec0.init(blockManager) + + val securityManager1 = new SecurityManager(conf1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1) + exec1.init(blockManager) + + val result = fetchBlock(exec0, exec1, "1", blockId) match { + case Success(buf) => + IOUtils.toString(buf.createInputStream()) should equal(blockString) + buf.release() + Success() + case Failure(t) => + Failure(t) + } + exec0.close() + exec1.close() + result + } + + /** Synchronously fetches a single block, acting as the given executor fetching from another. */ + private def fetchBlock( + self: BlockTransferService, + from: BlockTransferService, + execId: String, + blockId: BlockId): Try[ManagedBuffer] = { + + val promise = Promise[ManagedBuffer]() + + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + promise.failure(exception) + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + promise.success(data.retain()) + } + }) + + Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + promise.future.value.get + } +} + diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index b70734dfe37cf..716f875d30b8a 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") + conf.set("spark.app.id", "app-id") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) var numReceivedMessages = 0 @@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite { test("security mismatch password") { val conf = new SparkConf conf.set("spark.authenticate", "true") + conf.set("spark.app.id", "app-id") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val manager = new ConnectionManager(0, conf, securityManager) @@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite { None }) - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") + val badconf = conf.clone.set("spark.authenticate.secret", "bad") val badsecurityManager = new SecurityManager(badconf) val managerServer = new ConnectionManager(0, badconf, badsecurityManager) var numReceivedServerMessages = 0 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c6d7105592096..1461fa69db90d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val transfer = new NioBlockTransferService(conf, securityMgr) val store 
= new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + store.initialize("app-id") allStores += store store } @@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 715b740b857b2..0782876c8e3c6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) + manager.initialize("app-id") + manager } before { diff --git a/docs/security.md b/docs/security.md index ec0523184d665..1e206a139fb72 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index a271841e4e56c..5bc6e5a2418a9 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,12 +17,16 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; @@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.decoder = new MessageDecoder(); } + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. 
Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + public TransportClientFactory createClientFactory() { - return new TransportClientFactory(this); + return createClientFactory(Lists.newArrayList()); } /** Create a server which will attempt to bind to a specific port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 01c143fff423c..a08cee02dd576 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,10 +19,9 @@ import java.io.Closeable; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; @@ -186,4 +185,12 @@ public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000000..65e8020e34121 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. 
*/ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 0b4a1d8286407..1723fed307257 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -21,10 +21,14 @@ import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -40,6 +44,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -47,22 +52,29 @@ * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for - * all {@link TransportClient}s. + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. */ public class TransportClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); private final TransportContext context; private final TransportConf conf; + private final List clientBootstraps; private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private EventLoopGroup workerGroup; - public TransportClientFactory(TransportContext context) { - this.context = context; + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); @@ -72,9 +84,12 @@ public TransportClientFactory(TransportContext context) { } /** - * Create a new BlockFetchingClient connecting to the given remote host / port. + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. * - * This blocks until a connection is successfully established. + * This blocks until a connection is successfully established and fully bootstrapped. * * Concurrency: This method is safe to call from multiple threads. 
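A TransportClientBootstrap is simply a synchronous hook that runs on the newly connected client before the factory returns it; if it throws, the connection is closed and createClient fails. A minimal Scala sketch (LoggingBootstrap is a hypothetical name, not part of this patch):

    import org.apache.spark.network.client.{TransportClient, TransportClientBootstrap}

    class LoggingBootstrap extends TransportClientBootstrap {
      // Runs inside createClient(), after the channel is connected and before the client is cached.
      override def doBootstrap(client: TransportClient): Unit =
        println(s"bootstrapped $client")
    }
    // wired in via: transportContext.createClientFactory(java.util.Arrays.asList(new LoggingBootstrap))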
*/ @@ -104,17 +119,18 @@ public TransportClient createClient(String remoteHost, int remotePort) { // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); - final AtomicReference client = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); - client.set(clientHandler.getClient()); + clientRef.set(clientHandler.getClient()); } }); // Connect to the remote server + long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new RuntimeException( @@ -123,15 +139,35 @@ public void initChannel(SocketChannel ch) { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection -- in the event that two threads raced to create a client, we will + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); + } + long postBootstrap = System.currentTimeMillis(); + + // Successful connection & bootstrap -- in the event that two threads raced to create a client, // use the first one that was put into the connectionPool and close the one we made here. - assert client.get() != null : "Channel future completed successfully with null client"; - TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + TransportClient oldClient = connectionPool.putIfAbsent(address, client); if (oldClient == null) { - return client.get(); + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; } else { - logger.debug("Two clients were created concurrently, second one will be disposed."); - client.get().close(); + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); return oldClient; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 5a3f003726fc1..1502b7489e864 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -21,7 +21,7 @@ import org.apache.spark.network.client.TransportClient; /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. 
*/ -public class NoOpRpcHandler implements RpcHandler { +public class NoOpRpcHandler extends RpcHandler { private final StreamManager streamManager; public NoOpRpcHandler() { diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2369dc6203944..2ba92a40f8b0a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,22 +23,33 @@ /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ -public interface RpcHandler { +public abstract class RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. + * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. */ - void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. */ - StreamManager getStreamManager(); + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 17fe9001b35cc..1580180cc17e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -86,6 +86,7 @@ public void channelUnregistered() { for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); } + rpcHandler.connectionTerminated(reverseClient); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index a68f38e0e94c9..823790dd3c66f 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -55,4 +55,7 @@ public int connectionTimeoutMs() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. 
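RpcHandler becomes an abstract class so that connectionTerminated can ship with a no-op default; handlers that keep per-channel state override it to clean up, which is exactly what SaslRpcHandler does further down. A small Scala sketch of the new contract (EchoRpcHandler is a hypothetical name):

    import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
    import org.apache.spark.network.server.{RpcHandler, StreamManager}

    class EchoRpcHandler(streamManager: StreamManager) extends RpcHandler {
      // Echo every RPC straight back to the sender.
      override def receive(client: TransportClient, message: Array[Byte],
          callback: RpcResponseCallback): Unit = callback.onSuccess(message)

      override def getStreamManager(): StreamManager = streamManager

      // Per-channel cleanup hook added by this patch; the base-class default is a no-op.
      override def connectionTerminated(client: TransportClient): Unit = {
        // drop any per-connection state here; no further requests will arrive from this client
      }
    }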
*/ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000000..7bc91e375371f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. 
+ saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..5b77e18c26bf4 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + // tag + appIdLength + appId + payloadLength + payload + return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + byte[] idBytes = appId.getBytes(Charsets.UTF_8); + buf.writeInt(idBytes.length); + buf.writeBytes(idBytes); + buf.writeInt(payload.length); + buf.writeBytes(payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else"); + } + + int idLength = buf.readInt(); + byte[] idBytes = new byte[idLength]; + buf.readBytes(idBytes); + + int payloadLength = buf.readInt(); + byte[] payload = new byte[payloadLength]; + buf.readBytes(payload); + + return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000000..3777a18e33f78 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
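The SaslMessage defined above frames each handshake RPC as a 1-byte tag (0xEA), a 4-byte appId length, the UTF-8 appId bytes, a 4-byte payload length, and the payload bytes. A Scala sketch of decoding such a frame by hand, mirroring SaslMessage.decode (decodeSaslFrame is a hypothetical helper):

    import com.google.common.base.Charsets
    import io.netty.buffer.Unpooled

    def decodeSaslFrame(bytes: Array[Byte]): (String, Array[Byte]) = {
      val buf = Unpooled.wrappedBuffer(bytes)
      require(buf.readByte() == 0xEA.toByte, "Expected SaslMessage tag")
      val appId = new Array[Byte](buf.readInt())
      buf.readBytes(appId)
      val payload = new Array[Byte](buf.readInt())
      buf.readBytes(payload)
      (new String(appId, Charsets.UTF_8), payload)
    }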
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. 
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000000..81d5766794688 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. + */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000000..72ba737b998bc --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. */ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. 
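+   * The user name and password handed to the DIGEST-MD5 callbacks are derived from the shared
+   * secret obtained through the SecretKeyHolder for this client's secretKeyId.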
+ */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + logger.info("Realm callback"); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000000..2c0ce40c75e80 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. 
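+   * Both SparkSaslClient and SparkSaslServer pass the same value as the server name, so the two
+   * ends of the handshake agree on the realm.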
+   * This could be changed to be configurable in the future.
+   */
+  static final String DEFAULT_REALM = "default";
+
+  /**
+   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+   * configurable in the future.
+   */
+  static final String DIGEST = "DIGEST-MD5";
+
+  /**
+   * The quality of protection is just "auth". This means that we are doing
+   * authentication only, we are not supporting integrity or privacy protection of the
+   * communication channel after authentication. This could be changed to be configurable
+   * in the future.
+   */
+  static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+    .put(Sasl.QOP, "auth")
+    .put(Sasl.SERVER_AUTH, "true")
+    .build();
+
+  /** Identifier for a certain secret key within the secretKeyHolder. */
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+  private SaslServer saslServer;
+
+  public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+    try {
+      this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+        new DigestCallbackHandler());
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Determines whether the authentication exchange has completed successfully.
+   */
+  public synchronized boolean isComplete() {
+    return saslServer != null && saslServer.isComplete();
+  }
+
+  /**
+   * Used to respond to the client's SASL tokens.
+   * @param token The client's SASL token
+   * @return response to send back to the client.
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the
+   * SaslServer might be using.
+   */
+  public synchronized void dispose() {
+    if (saslServer != null) {
+      try {
+        saslServer.dispose();
+      } catch (SaslException e) {
+        // ignore
+      } finally {
+        saslServer = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
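+   * It supplies the expected user name and password for the connecting client and, through the
+   * AuthorizeCallback, marks the connection authorized when the authentication and
+   * authorization ids match.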
+ */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a9dff31decc83..cd3fea85b19a4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -41,7 +41,7 @@ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- * level shuffle block. */ -public class ExternalShuffleBlockHandler implements RpcHandler { +public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); private final ExternalShuffleBlockManager blockManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6bbabc44b958b..b0b19ba67bddc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.Closeable; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,15 +34,20 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. 
*/ -public class ExternalShuffleClient implements ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportClientFactory clientFactory; - private final String appId; - public ExternalShuffleClient(TransportConf conf, String appId) { + private String appId; + + public ExternalShuffleClient(TransportConf conf) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); this.clientFactory = context.createClientFactory(); + } + + @Override + public void init(String appId) { this.appId = appId; } @@ -55,6 +58,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener) { + assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { TransportClient client = clientFactory.createClient(host, port); @@ -82,6 +86,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) { + assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index d46a562394557..f72ab40690d0d 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -20,7 +20,14 @@ import java.io.Closeable; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public interface ShuffleClient extends Closeable { +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ + public void init(String appId) { } + /** * Fetch a sequence of blocks from a remote node asynchronously, * @@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - public void fetchBlocks( + public abstract void fetchBlocks( String host, int port, String execId, diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..84781207861ed --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. 
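+      // The server side was set up with "good-key" while this client authenticates with
+      // "bad-key", so the DIGEST-MD5 exchange is expected to end with a mismatched response.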
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. 
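+ * Each test drives a full challenge/response exchange between one client and one server
+ * instance and then checks the terminal state of both sides.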
+ */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b3bcf5fd68e73..bc101f53844d5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,8 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @Override @@ -164,6 +165,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } + client.close(); return res; } @@ -265,7 +267,8 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + ExternalShuffleClient client = new ExternalShuffleClient(conf); + client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ad1a6f01b3a57..0f27f55fec4f3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr)) + blockManager.initialize("app-id") tempDirectory = Files.createTempDir() manualClock.setTime(0) From 515abb9afa2d6b58947af6bb079a493b49d315ca Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 4 Nov 2014 18:14:28 -0800 Subject: [PATCH 31/79] [SQL] Add String option for DSL AS Author: Michael Armbrust Closes #3097 from marmbrus/asString and squashes the following commits: 6430520 [Michael Armbrust] Add String option for DSL AS --- .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3314e15477016..31dc5a58e68e5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -110,7 +110,8 @@ package object dsl { def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) - def as(s: Symbol) = Alias(expr, s.name)() + def as(alias: String) = Alias(expr, alias)() + def as(alias: Symbol) = Alias(expr, alias.name)() } trait ExpressionConversions { From c8abddc5164d8cf11cdede6ab3d5d1ea08028708 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Nov 2014 21:35:52 -0800 Subject: [PATCH 32/79] [SPARK-3964] [MLlib] [PySpark] add Hypothesis test Python API ``` pyspark.mllib.stat.StatisticschiSqTest(observed, expected=None) :: Experimental :: If `observed` is Vector, conduct Pearson's chi-squared goodness of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category having an expected frequency of `1 / len(observed)`. (Note: `observed` cannot contain negative values) If `observed` is matrix, conduct Pearson's independence test on the input contingency matrix, which cannot contain negative entries or columns or rows that sum up to 0. If `observed` is an RDD of LabeledPoint, conduct Pearson's independence test for every feature against the label across the input RDD. For each feature, the (feature, label) pairs are converted into a contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. :param observed: it could be a vector containing the observed categorical counts/relative frequencies, or the contingency matrix (containing either counts or relative frequencies), or an RDD of LabeledPoint containing the labeled dataset with categorical features. Real-valued features will be treated as categorical for each distinct value. :param expected: Vector containing the expected categorical counts/relative frequencies. `expected` is rescaled if the `expected` sum differs from the `observed` sum. :return: ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, the method used, and the null hypothesis. 
``` Author: Davies Liu Closes #3091 from davies/his and squashes the following commits: 145d16c [Davies Liu] address comments 0ab0764 [Davies Liu] fix float 5097d54 [Davies Liu] add Hypothesis test Python API --- docs/mllib-statistics.md | 40 +++++ .../mllib/api/python/PythonMLLibAPI.scala | 26 ++++ python/pyspark/mllib/common.py | 7 +- python/pyspark/mllib/linalg.py | 13 +- python/pyspark/mllib/stat.py | 137 +++++++++++++++++- 5 files changed, 219 insertions(+), 4 deletions(-) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 10a5131c07414..ca8c29218f52d 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to
+run Pearson's chi-squared tests. The following example demonstrates how to run and interpret
+hypothesis tests.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.linalg import Vectors, Matrices
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat import Statistics
+
+sc = SparkContext()
+
+vec = Vectors.dense(...)  # a vector composed of the frequencies of events
+
+# compute the goodness of fit. If a second vector to test against is not supplied as a parameter,
+# the test runs against a uniform distribution.
+goodnessOfFitTestResult = Statistics.chiSqTest(vec)
+print goodnessOfFitTestResult  # summary of the test including the p-value, degrees of freedom,
+                               # test statistic, the method used, and the null hypothesis.
+
+mat = Matrices.dense(...)  # a contingency matrix
+
+# conduct Pearson's independence test on the input contingency matrix
+independenceTestResult = Statistics.chiSqTest(mat)
+print independenceTestResult  # summary of the test including the p-value, degrees of freedom...
+
+obs = sc.parallelize(...)  # an RDD of LabeledPoint(feature, label)
+
+# The contingency table is constructed from an RDD of LabeledPoint and used to conduct
+# the independence test. Returns an array containing the ChiSquaredTestResult for every feature
+# against the label.
+featureTestResults = Statistics.chiSqTest(obs)
+
+for i, result in enumerate(featureTestResults):
+    print "Column %d:" % (i + 1)
+    print result
+{% endhighlight %}
+</div>
    + ## Random data generation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 65b98a8ceea55..d832ae34b55e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -454,6 +455,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index dbe5f698b7345..c6149fe391ec8 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -98,8 +98,13 @@ def _java2py(sc, r): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) - elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable if isinstance(r, bytearray): r = PickleSerializer().loads(str(r)) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index c0c3dff31e7f8..e35202dca0acc 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -33,7 +33,7 @@ IntegerType, ByteType, Row -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -578,6 +578,8 @@ class DenseMatrix(Matrix): def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) assert len(values) == numRows * numCols + if not isinstance(values, array.array): + values = array.array('d', values) self.values = values def __reduce__(self): @@ -596,6 +598,15 @@ def toArray(self): return np.reshape(self.values, (self.numRows, self.numCols), order='F') +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) diff --git 
a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 15f0652f833d7..0700f8a8e5a8e 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,11 +19,12 @@ Python package for statistical functions in MLlib. """ +from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import Matrix, _convert_to_vector -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -51,6 +52,54 @@ def min(self): return self.call("min").toArray() +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() + + class Statistics(object): @staticmethod @@ -135,6 +184,90 @@ def corr(x, y=None, method=None): else: return callMLlibFunc("corr", x.map(float), y.map(float), method) + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. 
+ + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> from pyspark.mllib.regression import LabeledPoint + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) + else: + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) + def _test(): import doctest From 5f13759d3642ea5b58c12a756e7125ac19aff10e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 5 Nov 2014 01:21:53 -0800 Subject: [PATCH 33/79] [SPARK-4029][Streaming] Update streaming driver to reliably save and recover received block metadata on driver failures As part of the initiative of preventing data loss on driver failure, this JIRA tracks the sub task of modifying the streaming driver to reliably save received block metadata, and recover them on driver restart. This was solved by introducing a `ReceivedBlockTracker` that takes all the responsibility of managing the metadata of received blocks (i.e. `ReceivedBlockInfo`, and any actions on them (e.g, allocating blocks to batches, etc.). All actions to block info get written out to a write ahead log (using `WriteAheadLogManager`). On recovery, all the actions are replaying to recreate the pre-failure state of the `ReceivedBlockTracker`, which include the batch-to-block allocations and the unallocated blocks. Furthermore, the `ReceiverInputDStream` was modified to create `WriteAheadLogBackedBlockRDD`s when file segment info is present in the `ReceivedBlockInfo`. After recovery of all the block info (through recovery `ReceivedBlockTracker`), the `WriteAheadLogBackedBlockRDD`s gets recreated with the recovered info, and jobs submitted. The data of the blocks gets pulled from the write ahead logs, thanks to the segment info present in the `ReceivedBlockInfo`. This is still a WIP. Things that are missing here are. - *End-to-end integration tests:* Unit tests that tests the driver recovery, by killing and restarting the streaming context, and verifying all the input data gets processed. This has been implemented but not included in this PR yet. 
A sneak peek of that DriverFailureSuite can be found in this PR (on my personal repo): https://github.com/tdas/spark/pull/25 I can either include it in this PR, or submit that as a separate PR after this gets in. - *WAL cleanup:* Cleaning up the received data write ahead log, by calling `ReceivedBlockHandler.cleanupOldBlocks`. This is being worked on. Author: Tathagata Das Closes #3026 from tdas/driver-ha-rbt and squashes the following commits: a8009ed [Tathagata Das] Added comment 1d704bb [Tathagata Das] Enabled storing recovered WAL-backed blocks to BM 2ee2484 [Tathagata Das] More minor changes based on PR 47fc1e3 [Tathagata Das] Addressed PR comments. 9a7e3e4 [Tathagata Das] Refactored ReceivedBlockTracker API a bit to make things a little cleaner for users of the tracker. af63655 [Tathagata Das] Minor changes. fce2b21 [Tathagata Das] Removed commented lines 59496d3 [Tathagata Das] Changed class names, made allocation more explicit and added cleanup 19aec7d [Tathagata Das] Fixed casting bug. f66d277 [Tathagata Das] Fix line lengths. cda62ee [Tathagata Das] Added license 25611d6 [Tathagata Das] Minor changes before submitting PR 7ae0a7fb [Tathagata Das] Transferred changes from driver-ha-working branch --- .../dstream/ReceiverInputDStream.scala | 69 +++-- .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../streaming/scheduler/JobGenerator.scala | 21 +- .../scheduler/ReceivedBlockTracker.scala | 230 +++++++++++++++++ .../streaming/scheduler/ReceiverTracker.scala | 98 ++++--- .../streaming/BasicOperationsSuite.scala | 19 +- .../streaming/ReceivedBlockTrackerSuite.scala | 242 ++++++++++++++++++ .../WriteAheadLogBackedBlockRDDSuite.scala | 4 +- 8 files changed, 597 insertions(+), 89 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index bb47d373de63d..3e67161363e50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -17,15 +17,14 @@ package org.apache.spark.streaming.dstream -import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.SparkException /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -40,9 +39,6 @@ import org.apache.spark.SparkException abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** Keeps all received blocks information */ - private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] - /** This is an unique identifier for the network input stream. 
*/ val id = ssc.getNewReceiverStreamId() @@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont def stop() {} - /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */ + /** + * Generates RDDs with blocks received by the receiver of this stream. */ override def compute(validTime: Time): Option[RDD[T]] = { - // If this is called for any time before the start time of the context, - // then this returns an empty RDD. This may happen when recovering from a - // master failure - if (validTime >= graph.startTime) { - val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) - receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] } - Some(new BlockRDD[T](ssc.sc, blockIds)) - } else { - Some(new BlockRDD[T](ssc.sc, Array.empty)) - } - } + val blockRDD = { - /** Get information on received blocks. */ - private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) + if (validTime < graph.startTime) { + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // driver failure without any write ahead log to recover pre-failure data. + new BlockRDD[T](ssc.sc, Array.empty) + } else { + // Otherwise, ask the tracker for all the blocks that have been allocated to this stream + // for this batch + val blockInfos = + ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) + val blockStoreResults = blockInfos.map { _.blockStoreResult } + val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Check whether all the results are of the same type + val resultTypes = blockStoreResults.map { _.getClass }.distinct + if (resultTypes.size > 1) { + logWarning("Multiple result types in block information, WAL information will be ignored.") + } + + // If all the results are of type WriteAheadLogBasedStoreResult, then create + // WriteAheadLogBackedBlockRDD else create simple BlockRDD. + if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { + val logSegments = blockStoreResults.map { + _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + }.toArray + // Since storeInBlockManager = false, the storage level does not matter. 
+ new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, + blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + } else { + new BlockRDD[T](ssc.sc, blockIds) + } + } + } + Some(blockRDD) } /** @@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ private[streaming] override def clearMetadata(time: Time) { super.clearMetadata(time) - val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) - receivedBlockInfo --= oldReceivedBlocks.keys - logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 23295bf658712..dd1e96334952f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition( * If it does not find them, it looks up the corresponding file segment. * * @param sc SparkContext - * @param hadoopConfig Hadoop configuration * @param blockIds Ids of the blocks that contains this RDD's data * @param segments Segments in write ahead logs that contain this RDD's data * @param storeInBlockManager Whether to store in the block manager after reading from the segment @@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition( private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, - @transient hadoopConfig: Configuration, @transient blockIds: Array[BlockId], @transient segments: Array[WriteAheadLogFileSegment], storeInBlockManager: Boolean, @@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( s"the same as number of segments (${segments.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. + @transient private val hadoopConfig = sc.hadoopConfiguration private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) override def getPartitions: Array[Partition] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 7d73ada12d107..39b66e1130768 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Wait until all the received blocks in the network input tracker has // been consumed by network input DStreams, and jobs have been generated with them logInfo("Waiting for all received blocks to be consumed for job generation") - while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) { + while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) { Thread.sleep(pollTime) } logInfo("Waited for all received blocks to be consumed for job generation") @@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. 
*/ private def generateJobs(time: Time) { - Try(graph.generateJobs(time)) match { + // Set the SparkEnv in this thread, so that job generation code can access the environment + // Example: BlockRDDs are created in this thread, and it needs to access BlockManager + // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. + SparkEnv.set(ssc.env) + Try { + jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch + graph.generateJobs(time) // generate jobs using allocated block + } match { case Success(jobs) => - val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => - val streamId = stream.id - val receivedBlockInfo = stream.getReceivedBlockInfo(time) - (streamId, receivedBlockInfo) - }.toMap - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + val receivedBlockInfos = + jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) + jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala new file mode 100644 index 0000000000000..5f5e1909908d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait representing any event in the ReceivedBlockTracker that updates its state. 
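+ * Events are written to a write ahead log (when enabled) and replayed on recovery to rebuild
+ * the tracker's state.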
*/ +private[streaming] sealed trait ReceivedBlockTrackerLogEvent + +private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks) + extends ReceivedBlockTrackerLogEvent +private[streaming] case class BatchCleanupEvent(times: Seq[Time]) + extends ReceivedBlockTrackerLogEvent + + +/** Class representing the blocks of all the streams allocated to a batch */ +private[streaming] +case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { + def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { + streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + } +} + +/** + * Class that keep track of all the received blocks, and allocate them to batches + * when required. All actions taken by this class can be saved to a write ahead log + * (if a checkpoint directory has been provided), so that the state of the tracker + * (received blocks and block-to-batch allocations) can be recovered after driver failure. + * + * Note that when any instance of this class is created with a checkpoint directory, + * it will try reading events from logs in the directory. + */ +private[streaming] class ReceivedBlockTracker( + conf: SparkConf, + hadoopConf: Configuration, + streamIds: Seq[Int], + clock: Clock, + checkpointDirOption: Option[String]) + extends Logging { + + private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] + + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] + private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] + + private val logManagerRollingIntervalSecs = conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) + private val logManagerOption = checkpointDirOption.map { checkpointDir => + new WriteAheadLogManager( + ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir), + hadoopConf, + rollingIntervalSecs = logManagerRollingIntervalSecs, + callerName = "ReceivedBlockHandlerMaster", + clock = clock + ) + } + + private var lastAllocatedBatchTime: Time = null + + // Recover block information from write ahead logs + recoverFromWriteAheadLogs() + + /** Add received block. This event will get written to the write ahead log (if enabled). */ + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + try { + writeToLog(BlockAdditionEvent(receivedBlockInfo)) + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + true + } catch { + case e: Exception => + logError(s"Error adding block $receivedBlockInfo", e) + false + } + } + + /** + * Allocate all unallocated blocks to the given batch. + * This event will get written to the write ahead log (if enabled). 
+ */ + def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { + if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { + val streamIdToBlocks = streamIds.map { streamId => + (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + }.toMap + val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) + writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + timeToAllocatedBlocks(batchTime) = allocatedBlocks + lastAllocatedBatchTime = batchTime + allocatedBlocks + } else { + throw new SparkException(s"Unexpected allocation of blocks, " + + s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + } + } + + /** Get the blocks allocated to the given batch. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized { + timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty) + } + + /** Get the blocks allocated to the given batch and stream. */ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + timeToAllocatedBlocks.get(batchTime).map { + _.getBlocksOfStream(streamId) + }.getOrElse(Seq.empty) + } + } + + /** Check if any blocks are left to be allocated to batches. */ + def hasUnallocatedReceivedBlocks: Boolean = synchronized { + !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty) + } + + /** + * Get blocks that have been added but not yet allocated to any batch. This method + * is primarily used for testing. + */ + def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized { + getReceivedBlockQueue(streamId).toSeq + } + + /** Clean up block information of old batches. */ + def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + assert(cleanupThreshTime.milliseconds < clock.currentTime()) + val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq + logInfo("Deleting batches " + timesToCleanup) + writeToLog(BatchCleanupEvent(timesToCleanup)) + timeToAllocatedBlocks --= timesToCleanup + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + log + } + + /** Stop the block tracker. */ + def stop() { + logManagerOption.foreach { _.stop() } + } + + /** + * Recover all the tracker actions from the write ahead logs to recover the state (unallocated + * and allocated block info) prior to failure. + */ + private def recoverFromWriteAheadLogs(): Unit = synchronized { + // Insert the recovered block information + def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { + logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + + // Insert the recovered block-to-batch allocations and clear the queue of received blocks + // (when the blocks were originally allocated to the batch, the queue must have been cleared). 
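/*
 * Illustrative aside (not part of the patch): the tracker above is essentially
 * event-sourced -- each state change (block added, batch allocated, batches
 * cleaned) is appended to the write ahead log as an event before the in-memory
 * maps are touched, and recoverFromWriteAheadLogs() rebuilds those maps by
 * replaying the logged events in order. A minimal standalone sketch of that
 * idea, with hypothetical names and an in-memory buffer standing in for the
 * log manager, might look like this:
 */
import scala.collection.mutable

object WalReplaySketch {
  sealed trait Event
  case class BlockAdded(streamId: Int, blockId: Long) extends Event
  case class BatchAllocated(batchTime: Long, blocks: Map[Int, Seq[Long]]) extends Event
  case class BatchesCleaned(batchTimes: Seq[Long]) extends Event

  /** Rebuilds its state from `log` on construction, then keeps the log current. */
  class Tracker(log: mutable.Buffer[Event]) {
    private val unallocated = mutable.Map.empty[Int, Vector[Long]]
    private val allocated = mutable.Map.empty[Long, Map[Int, Seq[Long]]]

    log.foreach(applyEvent)  // recovery: replay every logged event, in order

    def addBlock(streamId: Int, blockId: Long): Unit =
      record(BlockAdded(streamId, blockId))

    def allocateBatch(batchTime: Long): Unit =
      record(BatchAllocated(batchTime, unallocated.toMap))

    def blocksOf(batchTime: Long): Map[Int, Seq[Long]] =
      allocated.getOrElse(batchTime, Map.empty)

    def cleanup(batchTimes: Seq[Long]): Unit =
      record(BatchesCleaned(batchTimes))

    // Append the event first, then mutate, so an acknowledged update is never lost.
    private def record(e: Event): Unit = { log += e; applyEvent(e) }

    private def applyEvent(e: Event): Unit = e match {
      case BlockAdded(sid, bid) =>
        unallocated(sid) = unallocated.getOrElse(sid, Vector.empty) :+ bid
      case BatchAllocated(t, blocks) =>
        allocated(t) = blocks
        unallocated.clear()
      case BatchesCleaned(times) =>
        allocated --= times
    }
  }
}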
+ def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { + logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + + s"${allocatedBlocks.streamIdToAllocatedBlocks}") + streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + lastAllocatedBatchTime = batchTime + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + } + + // Cleanup the batch allocations + def cleanupBatches(batchTimes: Seq[Time]) { + logTrace(s"Recovery: Cleaning up batches $batchTimes") + timeToAllocatedBlocks --= batchTimes + } + + logManagerOption.foreach { logManager => + logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") + logManager.readFromLog().foreach { byteBuffer => + logTrace("Recovering record " + byteBuffer) + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + case BlockAdditionEvent(receivedBlockInfo) => + insertAddedBlock(receivedBlockInfo) + case BatchAllocationEvent(time, allocatedBlocks) => + insertAllocatedBatch(time, allocatedBlocks) + case BatchCleanupEvent(batchTimes) => + cleanupBatches(batchTimes) + } + } + } + } + + /** Write an update to the tracker to the write ahead log */ + private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + logDebug(s"Writing to log $record") + logManagerOption.foreach { logManager => + logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + } + } + + /** Get the queue of received blocks belonging to a particular stream */ + private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = { + streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue) + } +} + +private[streaming] object ReceivedBlockTracker { + def checkpointDirToLogDir(checkpointDir: String): String = { + new Path(checkpointDir, "receivedBlockMetadata").toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d696563bcee83..1c3984d968d20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,15 +17,16 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} + +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import akka.actor._ -import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException} + +import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} -import org.apache.spark.util.AkkaUtils /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * This class manages the execution of the receivers of NetworkInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() * has been called because it needs the final set of input streams at the time of instantiation. + * + * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing. 
*/ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging { - val receiverInputStreams = ssc.graph.getReceiverInputStreams() - val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) - val receiverExecutor = new ReceiverLauncher() - val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] - val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] - val timeout = AkkaUtils.askTimeout(ssc.conf) - val listenerBus = ssc.scheduler.listenerBus + private val receiverInputStreams = ssc.graph.getReceiverInputStreams() + private val receiverInputStreamIds = receiverInputStreams.map { _.id } + private val receiverExecutor = new ReceiverLauncher() + private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] + private val receivedBlockTracker = new ReceivedBlockTracker( + ssc.sparkContext.conf, + ssc.sparkContext.hadoopConfiguration, + receiverInputStreamIds, + ssc.scheduler.clock, + Option(ssc.checkpointDir) + ) + private val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped - var actor: ActorRef = null - var currentTime: Time = null + private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ def start() = synchronized { @@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { if (!receiverInputStreams.isEmpty) { actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), "ReceiverTracker") - receiverExecutor.start() + if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } @@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def stop() = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers - receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop() // Finally, stop the actor ssc.env.actorSystem.stop(actor) actor = null + receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { - val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) - logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") - receivedBlockInfo.toArray + /** Allocate all unallocated blocks to the given batch. */ + def allocateBlocksToBatch(batchTime: Time): Unit = { + if (receiverInputStreams.nonEmpty) { + receivedBlockTracker.allocateBlocksToBatch(batchTime) + } + } + + /** Get the blocks for the given batch and all input streams. */ + def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = { + receivedBlockTracker.getBlocksOfBatch(batchTime) } - private def getReceivedBlockInfoQueue(streamId: Int) = { - receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + /** Get the blocks allocated to the given batch and stream. 
*/ + def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { + synchronized { + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) + } + } + + /** Clean up metadata older than the given threshold time */ + def cleanupOldMetadata(cleanupThreshTime: Time) { + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) } /** Register a receiver */ - def registerReceiver( + private def registerReceiver( streamId: Int, typ: String, host: String, receiverActor: ActorRef, sender: ActorRef ) { - if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) } /** Deregister a receiver */ - def deregisterReceiver(streamId: Int, message: String, error: String) { + private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) @@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Add new blocks for the given stream */ - def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { - getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockStoreResult.blockId) + private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { + receivedBlockTracker.addBlock(receivedBlockInfo) } /** Report error sent by a receiver */ - def reportError(streamId: Int, message: String, error: String) { + private def reportError(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => oldInfo.copy(lastErrorMessage = message, lastError = error) @@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) } receiverInfo(streamId) = newReceiverInfo - ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Check if any blocks are left to be processed */ - def hasMoreReceivedBlockIds: Boolean = { - !receivedBlockInfo.values.forall(_.isEmpty) + def hasUnallocatedBlocks: Boolean 
= { + receivedBlockTracker.hasUnallocatedReceivedBlocks } /** Actor to receive messages from the receivers. */ @@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true case AddBlock(receivedBlockInfo) => - addBlocks(receivedBlockInfo) - sender ! true + sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { + SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") @@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) logInfo("All of the receivers have been terminated") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6c8bb50145367..dbab685dc3511 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.StreamingContext._ - -import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.SparkContext._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.language.existentials +import scala.reflect.ClassTag import util.ManualClock -import org.apache.spark.{SparkException, SparkConf} -import org.apache.spark.streaming.dstream.{WindowedDStream, DStream} -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel -import scala.collection.mutable +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} class BasicOperationsSuite extends TestSuiteBase { test("map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala new file mode 100644 index 0000000000000..fd9c97f551c62 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} +import scala.util.Random + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogSuite._ +import org.apache.spark.util.Utils + +class ReceivedBlockTrackerSuite + extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") + conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + + val hadoopConf = new Configuration() + val akkaTimeout = 10 seconds + val streamId = 1 + + var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() + var checkpointDirectory: File = null + + before { + checkpointDirectory = Files.createTempDir() + } + + after { + allReceivedBlockTrackers.foreach { _.stop() } + if (checkpointDirectory != null && checkpointDirectory.exists()) { + FileUtils.deleteDirectory(checkpointDirectory) + checkpointDirectory = null + } + } + + test("block addition, and block to batch allocation") { + val receivedBlockTracker = createTracker(enableCheckpoint = false) + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + val blockInfos = generateBlockInfos() + blockInfos.map(receivedBlockTracker.addBlock) + + // Verify added blocks are unallocated blocks + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Allocate the blocks to a batch and verify that all of them have been allocated + receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + + // Allocate no blocks to another batch + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + + // Verify that batch 2 cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(2) + } + + // Verify that older batches cannot be allocated again + intercept[SparkException] { + receivedBlockTracker.allocateBlocksToBatch(1) + } + } + + test("block addition, block to batch allocation and cleanup with write ahead log") { + val manualClock = new ManualClock + conf.getInt( + "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1) + + // Set the time increment level to twice the rotation interval so that every 
increment creates + // a new log file + val timeIncrementMillis = 2000L + def incrementTime() { + manualClock.addToTime(timeIncrementMillis) + } + + // Generate and add blocks to the given tracker + def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = { + val blockInfos = generateBlockInfos() + blockInfos.map(tracker.addBlock) + blockInfos + } + + // Print the data present in the log ahead files in the log directory + def printLogFiles(message: String) { + val fileContents = getWriteAheadLogFiles().map { file => + (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}") + }.mkString("\n") + logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") + } + + // Start tracker and add blocks + val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock) + val blockInfos1 = addBlockInfos(tracker1) + tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Verify whether write ahead log has correct contents + val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent) + getWrittenLogData() shouldEqual expectedWrittenData1 + getWriteAheadLogFiles() should have size 1 + + // Restart tracker and verify recovered list of unallocated blocks + incrementTime() + val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Allocate blocks to batch and verify whether the unallocated blocks got allocated + val batchTime1 = manualClock.currentTime + tracker2.allocateBlocksToBatch(batchTime1) + tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + + // Add more blocks and allocate to another batch + incrementTime() + val batchTime2 = manualClock.currentTime + val blockInfos2 = addBlockInfos(tracker2) + tracker2.allocateBlocksToBatch(batchTime2) + tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + + // Verify whether log has correct contents + val expectedWrittenData2 = expectedWrittenData1 ++ + Seq(createBatchAllocation(batchTime1, blockInfos1)) ++ + blockInfos2.map(BlockAdditionEvent) ++ + Seq(createBatchAllocation(batchTime2, blockInfos2)) + getWrittenLogData() shouldEqual expectedWrittenData2 + + // Restart tracker and verify recovered state + incrementTime() + val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock) + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + tracker3.getUnallocatedBlocks(streamId) shouldBe empty + + // Cleanup first batch but not second batch + val oldestLogFile = getWriteAheadLogFiles().head + incrementTime() + tracker3.cleanupOldBatches(batchTime2) + + // Verify that the batch allocations have been cleaned, and the act has been written to log + tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty + getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1)) + + // Verify that at least one log file gets deleted + eventually(timeout(10 seconds), interval(10 millisecond)) { + getWriteAheadLogFiles() should not contain oldestLogFile + } + printLogFiles("After cleanup") + + // Restart tracker and verify recovered state, specifically whether info about the first + // batch has been removed, but not the second batch + incrementTime() + val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock) + 
tracker4.getUnallocatedBlocks(streamId) shouldBe empty + tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned + tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 + } + + /** + * Create tracker object with the optional provided clock. Use fake clock if you + * want to control time by manually incrementing it to test log cleanup. + */ + def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = { + val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None + val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + allReceivedBlockTrackers += tracker + tracker + } + + /** Generate blocks infos using random ids */ + def generateBlockInfos(): Seq[ReceivedBlockInfo] = { + List.fill(5)(ReceivedBlockInfo(streamId, 0, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + } + + /** Get all the data written in the given write ahead log file. */ + def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { + getWrittenLogData(Seq(logFile)) + } + + /** + * Get all the data written in the given write ahead log files. By default, it will read all + * files in the test log directory. + */ + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + logFiles.flatMap { + file => new WriteAheadLogReader(file, hadoopConf).toSeq + }.map { byteBuffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.toList + } + + /** Get all the write ahead log files in the test directory */ + def getWriteAheadLogFiles(): Seq[String] = { + import ReceivedBlockTracker._ + val logDir = checkpointDirToLogDir(checkpointDirectory.toString) + getLogFilesInDirectory(logDir).map { _.toString } + } + + /** Create batch allocation object from the given info */ + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) + } + + /** Create batch cleanup object from the given info */ + def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { + BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) + } + + implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds) + + implicit def timeToMillis(time: Time): Long = time.milliseconds +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 10160244bcc91..d2b983c4b4d1a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { ) // Create the RDD and verify whether the returned data is correct - val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { - val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + val rdd2 = new 
WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( From 73d80170d70adfb37194a9dead0408512df863b0 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 11:22:29 -0600 Subject: [PATCH 34/79] In tests where the scalatest assert conversion collides with the new DSL conversion (due to the existence of an === operator), applied the transformation assert(X == Y) --> assert(convertToEqualizer(X).===(Y)) --- .../ExpressionEvaluationSuite.scala | 89 ++--- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +- .../apache/spark/sql/CachedTableSuite.scala | 11 +- .../org/apache/spark/sql/DslQuerySuite.scala | 12 +- .../org/apache/spark/sql/JoinSuite.scala | 13 +- .../org/apache/spark/sql/SQLConfSuite.scala | 9 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 11 +- .../columnar/PartitionBatchPruningSuite.scala | 11 +- .../spark/sql/execution/PlannerSuite.scala | 17 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 311 +++++++++--------- .../spark/sql/hive/StatisticsSuite.scala | 25 +- .../sql/hive/execution/HiveQuerySuite.scala | 21 +- .../spark/sql/parquet/HiveParquetSuite.scala | 14 +- 13 files changed, 331 insertions(+), 238 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 6bfa0dbd65ba7..128a4860843ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -32,6 +32,13 @@ import org.apache.spark.sql.catalyst.types._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class ExpressionEvaluationSuite extends FunSuite { test("literals") { @@ -318,18 +325,18 @@ class ExpressionEvaluationSuite extends FunSuite { intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - assert(("abcdef" cast StringType).nullable === false) - assert(("abcdef" cast BinaryType).nullable === false) - assert(("abcdef" cast BooleanType).nullable === false) - assert(("abcdef" cast TimestampType).nullable === true) - assert(("abcdef" cast LongType).nullable === true) - assert(("abcdef" cast IntegerType).nullable === true) - assert(("abcdef" cast ShortType).nullable === true) - assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType.Unlimited).nullable === true) - assert(("abcdef" cast DecimalType(4, 2)).nullable === true) - assert(("abcdef" cast DoubleType).nullable === true) - assert(("abcdef" cast FloatType).nullable === true) + assert(EQ(("abcdef" cast StringType).nullable).===(false)) + assert(EQ(("abcdef" cast BinaryType).nullable).===(false)) + assert(EQ(("abcdef" cast BooleanType).nullable).===(false)) + assert(EQ(("abcdef" cast TimestampType).nullable).===(true)) + assert(EQ(("abcdef" cast LongType).nullable).===(true)) + assert(EQ(("abcdef" cast IntegerType).nullable).===(true)) + assert(EQ(("abcdef" cast ShortType).nullable).===(true)) + assert(EQ(("abcdef" cast ByteType).nullable).===(true)) + assert(EQ(("abcdef" cast DecimalType.Unlimited).nullable).===(true)) + assert(EQ(("abcdef" cast DecimalType(4, 2)).nullable).===(true)) + assert(EQ(("abcdef" cast DoubleType).nullable).===(true)) + assert(EQ(("abcdef" cast FloatType).nullable).===(true)) checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) } @@ -346,15 +353,15 @@ class ExpressionEvaluationSuite extends FunSuite { // - Values that would overflow the target precision should turn into null // - Because of this, casts to fixed-precision decimals should be nullable - assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + assert(EQ(Cast(Literal(123), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03f), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable).===(false)) - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + assert(EQ(Cast(Literal(123), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03f), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable).===(true)) checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) @@ -500,26 +507,26 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, 
c6)), "c", row) checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5)).nullable).===(true)) val c4_notNull = 'a.boolean.notNull.at(3) val c5_notNull = 'a.boolean.notNull.at(4) val c6_notNull = 'a.boolean.notNull.at(5) - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable).===(true)) } test("complex type") { @@ -559,11 +566,11 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(EQ(GetField(BoundReference(2,typeS, nullable = true), "a").nullable).===(true)) + assert(EQ(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable).===(false)) - assert(GetField(Literal(null, typeS), "a").nullable === true) - assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(EQ(GetField(Literal(null, typeS), "a").nullable).===(true)) + assert(EQ(GetField(Literal(null, typeS_notNullable), "a").nullable).===(true)) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -717,10 +724,10 @@ class ExpressionEvaluationSuite extends FunSuite { val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(EQ(Substring(s, 
Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(false)) + assert(EQ(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable).===(true)) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 036fd3fa1d6a1..82887dc9d4604 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq def nullable = true @@ -36,21 +43,21 @@ case class Dummy(optKey: Option[Expression]) extends Expression { class TreeNodeSuite extends FunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } - assert(after === Literal(2)) + assert(EQ(after).===(Literal(2))) } test("one child changed") { val before = Add(Literal(1), Literal(2)) val after = before transform { case Literal(2, _) => Literal(1) } - assert(after === Add(Literal(1), Literal(1))) + assert(EQ(after).===(Add(Literal(1), Literal(1)))) } test("no change") { val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val after = before transform { case Literal(5, _) => Literal(1)} - assert(before === after) + assert(EQ(before).===(after)) // Ensure that the objects after are the same objects before the transformation. 
before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach { case (b, a) => assert(b eq a) @@ -61,7 +68,7 @@ class TreeNodeSuite extends FunSuite { val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val literals = tree collect {case l: Literal => l} - assert(literals.size === 4) + assert(EQ(literals.size).===(4)) (1 to 4).foreach(i => assert(literals contains Literal(i))) } @@ -74,7 +81,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("post-order transform") { @@ -86,7 +93,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("transform works on nodes with Option children") { @@ -95,13 +102,13 @@ class TreeNodeSuite extends FunSuite { val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy1 transformUp toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy2 transform toZero - assert(actual === Dummy(None)) + assert(EQ(actual).===(Dummy(None))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1a5d87d5240e9..9eca55707ae52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,13 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class BigData(s: String) class CachedTableSuite extends QueryTest { @@ -86,7 +93,7 @@ class CachedTableSuite extends QueryTest { val data = "*" * 10000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) + assert(EQ(table("bigData").count()).===(200000L)) table("bigData").unpersist(blocking = true) } @@ -240,7 +247,7 @@ class CachedTableSuite extends QueryTest { table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum - assert(cached.statistics.sizeInBytes === actualSizeInBytes) + assert(EQ(cached.statistics.sizeInBytes).===(actualSizeInBytes)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a2..c6698d41f1c01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -25,6 +25,14 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + + class DslQuerySuite extends QueryTest { import TestData._ @@ -172,7 +180,7 @@ class DslQuerySuite extends QueryTest { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(EQ(testData2.count()).===(testData2.map(_ => 1).count())) } test("null count") { @@ -193,7 +201,7 @@ class DslQuerySuite extends QueryTest { } test("zero count") { - assert(emptyTableData.count() === 0) + assert(EQ(emptyTableData.count()).===(0)) } test("except") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8b4cf5bac0187..0260eea467d1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,6 +25,13 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOu import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -34,7 +41,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.as('y) val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } def assertJoin(sqlString: String, c: Class[_]): Any = { @@ -50,7 +57,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") } @@ -104,7 +111,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } test("inner join where, one match per row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 60701f0e154f8..4a09ed517d6e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -24,6 +24,13 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" @@ -38,7 +45,7 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { test("programmatic ways of basic setting and getting") { clear() - assert(getAllConfs.size === 0) + assert(EQ(getAllConfs.size).===(0)) setConf(testKey, testVal) assert(getConf(testKey) == testVal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ef9b76b1e251e..95582eec06975 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,18 +22,25 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("Simple UDF") { registerFunction("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(EQ(sql("SELECT strLenScala('test')").first().getInt(0)).===(4)) } test("TwoArgument UDF") { registerFunction("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(EQ(sql("SELECT strLenScala('test', 1)").first().getInt(0)).===(5)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 9ba3c210171bd..45086f88a27ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,6 +22,13 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning @@ -107,8 +114,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(EQ(readBatches).===(expectedReadBatches), "Wrong number of read batches") + assert(EQ(readPartitions).===(expectedReadPartitions), "Wrong number of read partitions") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a5af71acfc79a..0ed380c63527a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -34,8 +41,8 @@ class PlannerSuite extends FunSuite { val logicalUnions = query collect { case u: logical.Union => u } val physicalUnions = planned collect { case u: execution.Union => u } - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) + assert(EQ(logicalUnions.size).===(2)) + assert(EQ(physicalUnions.size).===(1)) } test("count is partially aggregated") { @@ -43,7 +50,7 @@ class PlannerSuite extends FunSuite { val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - assert(aggregations.size === 2) + assert(EQ(aggregations.size).===(2)) } test("count distinct is partially aggregated") { @@ -71,7 +78,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) @@ -91,7 +98,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") 
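/*
 * Illustrative aside (not part of the patch): the EQ(...) rewrites in these
 * suites spell out ScalaTest's triple-equals support by name. Per the commit
 * message, the bare form `assert(x === y)` can collide with the catalyst DSL
 * once its implicit conversions (which also supply an === operator) are
 * imported, so the tests call convertToEqualizer explicitly instead. A
 * minimal, self-contained sketch of the pattern (class and value names below
 * are illustrative) might look like this:
 */
import org.scalatest.FunSuite
import org.scalatest.Assertions.{convertToEqualizer => EQ}

class ExplicitEqualizerExample extends FunSuite {

  test("explicit Equalizer keeps === unambiguous") {
    val planned = Seq("BroadcastHashJoin")

    // Instead of `assert(planned.size === 1)`, call ScalaTest's conversion directly:
    assert(EQ(planned.size).===(1), "Should use broadcast hash join")

    // The same spelling works for any type with a sensible default equality.
    assert(EQ(planned.head).===("BroadcastHashJoin"))
  }
}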
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1b11..25031910c30de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestRDDEntry(key: Int, value: String) case class NullReflectData( @@ -172,7 +179,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -188,7 +195,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -204,7 +211,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) + assert(EQ(actualCodec).===("UNCOMPRESSED" :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -220,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -236,7 +243,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase 
:: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -285,8 +292,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") + assert(EQ(result.size).===(9), "self-join result has incorrect size") + assert(EQ(result(0).size).===(12), "result row has incorrect size") result.zipWithIndex.foreach { case (row, index) => row.zipWithIndex.foreach { case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") @@ -296,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Import of simple Parquet file") { val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) + assert(EQ(result.size).===(15)) result.zipWithIndex.foreach { case (row, index) => { val checkBoolean = @@ -304,12 +311,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA row(0) == true else row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assert(EQ(checkBoolean).===(true), s"boolean field value in line $index did not match") + if (index % 5 == 0) assert(EQ(row(1)).===(5), s"int field value in line $index did not match") + assert(EQ(row(2)).===("abc"), s"string field value in line $index did not match") + assert(EQ(row(3)).===((index.toLong << 33)), s"long value in line $index did not match") + assert(EQ(row(4)).===(2.5F), s"float field value in line $index did not match") + assert(EQ(row(5)).===(4.5D), s"double field value in line $index did not match") } } } @@ -319,11 +326,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") + assert(EQ(row(0)).===(true), s"boolean field value in line $index did not match (every third row)") else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") + assert(EQ(row(0)).===(false), s"boolean field value in line $index did not match") + assert(EQ(row(1)).===((index.toLong << 33)), s"long field value in line $index did not match") + assert(EQ(row.size).===(2), s"number of columns in projection in line $index is incorrect") } } } @@ -381,8 +388,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + 
assert(EQ(rdd_copy(i).apply(0)).===(rdd_orig(i).key), s"key error in line $i") + assert(EQ(rdd_copy(i).apply(1)).===(rdd_orig(i).value), s"value error in line $i") } Utils.deleteRecursively(file) } @@ -396,11 +403,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() - assert(rdd_copy1.size === 100) + assert(EQ(rdd_copy1.size).===(100)) sql("INSERT INTO dest SELECT * FROM source") val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) - assert(rdd_copy2.size === 200) + assert(EQ(rdd_copy2.size).===(200)) Utils.deleteRecursively(dirname) } @@ -408,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) - assert(double_rdd.size === 30) + assert(EQ(double_rdd.size).===(30)) // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) @@ -425,7 +432,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -440,7 +447,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -478,11 +485,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val attribute2 = new AttributeReference("second", IntegerType, false)() val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) - assert(badfilter.isDefined === false) + assert(EQ(badfilter.isDefined).===(false)) val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) val badfilter2 = ParquetFilters.createFilter(predicate6) - assert(badfilter2.isDefined === false) + assert(EQ(badfilter2.isDefined).===(false)) } test("test filter by predicate pushdown") { @@ -492,21 +499,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) + assert(EQ(result1.size).===(50)) + assert(EQ(result1(0)(1)).===(100)) + assert(EQ(result1(49)(1)).===(149)) val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result2 = query2.collect() - assert(result2.size === 50) + assert(EQ(result2.size).===(50)) if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) + assert(EQ(result2(0)(1)).===(151)) + assert(EQ(result2(49)(1)).===(200)) } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) + assert(EQ(result2(0)(1)).===(150)) + assert(EQ(result2(49)(1)).===(199)) } 
} for(myval <- Seq("myint", "mylong")) { @@ -515,11 +522,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) + assert(EQ(result3.size).===(20)) + assert(EQ(result3(0)(1)).===(0)) + assert(EQ(result3(9)(1)).===(9)) + assert(EQ(result3(10)(1)).===(191)) + assert(EQ(result3(19)(1)).===(200)) } for(myval <- Seq("mydouble", "myfloat")) { val result4 = @@ -534,18 +541,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // currently no way to specify float constants in SqlParser? sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) + assert(EQ(result4.size).===(20)) + assert(EQ(result4(0)(1)).===(0)) + assert(EQ(result4(9)(1)).===(9)) + assert(EQ(result4(10)(1)).===(191)) + assert(EQ(result4(19)(1)).===(200)) } val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") assert( query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val booleanResult = query5.collect() - assert(booleanResult.size === 10) + assert(EQ(booleanResult.size).===(10)) for(i <- 0 until 10) { if (!booleanResult(i).getBoolean(0)) { fail(s"Boolean value in result row $i not true") @@ -559,16 +566,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val stringResult = query6.collect() - assert(stringResult.size === 1) + assert(EQ(stringResult.size).===(1)) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) + assert(EQ(stringResult(0).getInt(1)).===(100)) val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") assert( query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val optResult = query7.collect() - assert(optResult.size === 20) + assert(EQ(optResult.size).===(20)) for(i <- 0 until 20) { if (optResult(i)(7) != i * 2) { fail(s"optional Int value in result row $i should be ${2*4*i}") @@ -580,21 +587,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) + assert(EQ(result8.size).===(25)) + assert(EQ(result8(0)(7)).===(100)) + assert(EQ(result8(24)(7)).===(148)) val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result9 = query9.collect() - assert(result9.size === 25) + assert(EQ(result9.size).===(25)) if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) 
=== 152) - assert(result9(24)(7) === 200) + assert(EQ(result9(0)(7)).===(152)) + assert(EQ(result9(24)(7)).===(200)) } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) + assert(EQ(result9(0)(7)).===(150)) + assert(EQ(result9(24)(7)).===(198)) } } val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") @@ -602,15 +609,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result10 = query10.collect() - assert(result10.size === 1) + assert(EQ(result10.size).===(1)) assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) + assert(EQ(result10(0).getInt(7)).===(100)) val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") assert( query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result11 = query11.collect() - assert(result11.size === 7) + assert(EQ(result11.size).===(7)) for(i <- 0 until 6) { if (!result11(i).getBoolean(6)) { fail(s"optional Boolean value in result row $i not true") @@ -623,7 +630,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) + assert(EQ(query.collect().size).===(10)) } test("Importing nested Parquet file (Addressbook)") { @@ -632,32 +639,32 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .collect() assert(result != null) - assert(result.size === 2) + assert(EQ(result.size).===(2)) val first_record = result(0) val second_record = result(1) assert(first_record != null) assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. Nonymous") - assert(first_record(0) === "Julien Le Dem") + assert(EQ(first_record.size).===(3)) + assert(EQ(second_record(1)).===(null)) + assert(EQ(second_record(2)).===(null)) + assert(EQ(second_record(0)).===("A. 
Nonymous")) + assert(EQ(first_record(0)).===("Julien Le Dem")) val first_owner_numbers = first_record(1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] val first_contacts = first_record(2) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + assert(EQ(first_owner_numbers(0)).===("555 123 4567")) + assert(EQ(first_owner_numbers(2)).===("XXX XXX XXXX")) + assert(EQ(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size).===(2)) val first_contacts_entry_one = first_contacts(0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") + assert(EQ(first_contacts_entry_one(0)).===("Dmitriy Ryaboy")) + assert(EQ(first_contacts_entry_one(1)).===("555 987 6543")) val first_contacts_entry_two = first_contacts(1) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + assert(EQ(first_contacts_entry_two(0)).===("Chris Aniszczyk")) } test("Importing nested Parquet file (nested numbers)") { @@ -665,31 +672,31 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) + assert(EQ(result.size).===(1), "number of top-level rows incorrect") + assert(EQ(result(0).size).===(5), "number of fields in row incorrect") + assert(EQ(result(0)(0)).===(1)) + assert(EQ(result(0)(1)).===(7)) val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) + assert(EQ(subresult1.size).===(3)) + assert(EQ(subresult1(0)).===((1.toLong << 32))) + assert(EQ(subresult1(1)).===((1.toLong << 33))) + assert(EQ(subresult1(2)).===((1.toLong << 34))) val subresult2 = result(0)(3) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) + assert(EQ(subresult2.size).===(2)) + assert(EQ(subresult2(0)).===(2.5)) + assert(EQ(subresult2(1)).===(false)) val subresult3 = result(0)(4) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + assert(EQ(subresult3.size).===(2)) + assert(EQ(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(2)) val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + 
assert(EQ(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(1)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("Simple query on addressbook") { @@ -697,8 +704,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") + assert(EQ(tmp.size).===(1)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) } test("Projection in addressbook") { @@ -706,37 +713,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA data.registerTempTable("data") val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. Nonymous") - assert(tmp(1)(1) === null) + assert(EQ(tmp.size).===(2)) + assert(EQ(tmp(0).size).===(2)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) + assert(EQ(tmp(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmp(1)(0)).===("A. Nonymous")) + assert(EQ(tmp(1)(1)).===(null)) } test("Simple query on nested int data") { val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(2.5)) val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) + assert(EQ(subresult1.size).===(2)) + assert(EQ(subresult1(0)).===(2.5)) + assert(EQ(subresult1(1)).===(false)) val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) + assert(EQ(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("nested structs") { @@ -744,17 +751,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) + 
assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(false)) val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0).size).===(1)) + assert(EQ(result2(0)(0)).===(true)) val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0).size).===(1)) + assert(EQ(result3(0)(0)).===(false)) } test("simple map") { @@ -763,38 +770,38 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) + .getOrElse("key1", 0)).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) + .getOrElse("key2", 0)).===(2)) val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) + assert(EQ(result2(0)(0)).===(1)) } test("map with struct values") { val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) + assert(EQ(result1.size).===(1)) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0)(0)).===(42.toLong)) + assert(EQ(result2(0)(1)).===("the answer")) } test("Writing out Addressbook and reading it back in") { @@ -808,12 +815,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpcopy") val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) + assert(EQ(tmpdata.size).===(2)) + assert(EQ(tmpdata(0).size).===(2)) + assert(EQ(tmpdata(0)(0)).===("Julien Le Dem")) + assert(EQ(tmpdata(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmpdata(1)(0)).===("A. 
Nonymous")) + assert(EQ(tmpdata(1)(1)).===(null)) Utils.deleteRecursively(tmpdir) } @@ -826,26 +833,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpmapcopy") val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0)).===(2)) val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0)(0)).===(42.toLong)) + assert(EQ(result3(0)(1)).===("the answer")) Utils.deleteRecursively(tmpdir) } @@ -854,7 +861,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(tmpdir) createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") val result1 = sql("SELECT * FROM tmpemptytable").collect() - assert(result1.size === 0) + assert(EQ(result1.size).===(0)) Utils.deleteRecursively(tmpdir) } @@ -868,7 +875,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) - assert(a.dataType === b.dataType) + assert(EQ(a.dataType).===(b.dataType)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a90fc023e67d8..f3323952b75d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class StatisticsSuite extends QueryTest with BeforeAndAfterAll { TestHive.reset() TestHive.cacheTables = false @@ -39,7 +46,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { case o => o } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail( s"""$analyzeCommand expected command: $c, but got ${operators(0)} @@ -81,11 +88,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // TODO: How does it works? needs to add it back for other hive version. 
if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable")).===(defaultSizeInBytes)) } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) + assert(EQ(queryTotalSize("analyzeTable")).===(BigInt(11624))) sql("DROP TABLE analyzeTable").collect() @@ -110,11 +117,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable_part")).===(defaultSizeInBytes)) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + assert(EQ(queryTotalSize("analyzeTable_part")).===(BigInt(17436))) sql("DROP TABLE analyzeTable_part").collect() @@ -131,7 +138,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(EQ(sizes.size).===(1), s"Size wrong for:\n ${rdd.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -151,14 +158,14 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold + assert(EQ(sizes.size).===(2) && sizes(0) <= autoBroadcastJoinThreshold && sizes(1) <= autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } - assert(bhj.size === 1, + assert(EQ(bhj.size).===(1), s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") checkAnswer(rdd, expectedAnswer) // check correctness of output @@ -172,7 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } - assert(shj.size === 1, + assert(EQ(shj.size).===(1), "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5918f888c8f4c..b5e70e0f30599 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestData(a: Int, b: String) /** @@ -139,7 +146,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("CREATE TABLE AS runs once") { sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + assert(EQ(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0)).===(1), "Incorrect number of rows in created table") } @@ -161,7 +168,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(EQ(sql("SELECT 1").collect()).===(Array(Seq(1)))) setConf("spark.sql.dialect", "hiveql") } @@ -365,7 +372,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .toSet - assert(actual === expected) + assert(EQ(actual).===(expected)) } // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. @@ -415,7 +422,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .map(x => Pair(x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(EQ(results).===(Array(Pair("foo", 4)))) TestHive.reset() } @@ -557,8 +564,8 @@ class HiveQuerySuite extends HiveComparisonTest { sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => - assert(map.size === 1) - assert(map.head === (key, value)) + assert(EQ(map.size).===(1)) + assert(EQ(map.head).===((key, value))) } } @@ -654,7 +661,7 @@ class HiveQuerySuite extends HiveComparisonTest { sql("CREATE TABLE dp_verify(intcol INT)") sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + assert(EQ(sql("SELECT * FROM dp_verify").collect()).===(Array(Row(value)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 6f57fe8958387..be1fbca69708f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -29,6 +29,13 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(convertToEqualizer(X).===(true)) + * (This file already imports convertToEqualizer) + */ + case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -80,7 +87,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Simple column projection + filter on Parquet table") { val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() - assert(rdd.size === 5, "Filter returned incorrect number of rows") + assert(convertToEqualizer(rdd.size).===(5), "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } @@ -102,7 +109,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() val rddCopy = sql("SELECT * FROM ptable").collect() val rddOrig = sql("SELECT * FROM testsource").collect() - assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") + assert(convertToEqualizer(rddCopy.size).===(rddOrig.size), "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } @@ -111,7 +118,8 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft (rddOne, rddTwo).zipped.foreach { (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { case ((value_1, value_2), index) => - assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") + assert(convertToEqualizer(value_1).===(value_2), + s"table $tableName row $counter field ${fieldNames(index)} don't match") } counter = counter + 1 } From a9112401483d88e8a0200caaa777dec49aee70c4 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 11:30:12 -0600 Subject: [PATCH 35/79] Removed Date and Timestamp from NativeTypes as this would force changes in the code generator, and this PR is already too big. Now passes the standard test suite (except SparkSubmitSuite, which I think is unrelated to these changes.) --- .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2 +- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 5 +---- .../src/main/scala/org/apache/spark/sql/package.scala | 8 ++++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index bc4fbac7af1e1..4ba52c724f03a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -151,7 +151,7 @@ package object dsl { * where a literal is being combined with a symbol. Without these an * expression such as 0 < 'x is not recognized. 
*/ - case class LhsLiteral(x: Any) { + class LhsLiteral(x: Any) { val literal = Literal(x) def + (other: Symbol) = Add(literal, other) def - (other: Symbol) = Subtract(literal, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 7782cb05b8a40..8dda0b182805c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -187,8 +187,7 @@ case object NullType extends DataType object NativeType { val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, - ByteType, StringType, DateType, TimestampType) + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -200,8 +199,6 @@ object NativeType { FloatType -> 4, ShortType -> 2, ByteType -> 1, - DateType -> 8, - TimestampType -> 12, StringType -> 4096) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index ea2f6ef103d22..1527196eb19cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -460,13 +460,13 @@ package object sql { * are provided. The class intializer accepts a String, e.g. * * {{{ - * val d = Date("2014-01-01") + * val d = RichDate("2014-01-01") * }}} * * @group dataType */ @DeveloperApi - val Date = catalyst.expressions.RichDate + val RichDate = catalyst.expressions.RichDate /** * :: DeveloperApi :: @@ -477,11 +477,11 @@ package object sql { * String, e.g. * * {{{ - * val ts = Timestamp("2014-01-01 12:34:56.78") + * val ts = RichTimestamp("2014-01-01 12:34:56.78") * }}} * * @group timeClasses */ @DeveloperApi - val Timestamp = catalyst.expressions.RichTimestamp + val RichTimestamp = catalyst.expressions.RichTimestamp } From 5b3b6f6f5f029164d7749366506e142b104c1d43 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 10:33:13 -0800 Subject: [PATCH 36/79] [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. 
* Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes --- .../mllib/JavaGradientBoostedTrees.java | 126 +++++++++++++ .../examples/mllib/DecisionTreeRunner.scala | 64 +++++-- .../examples/mllib/GradientBoostedTrees.scala | 146 +++++++++++++++ .../spark/mllib/tree/GradientBoosting.scala | 169 ++++++------------ .../tree/configuration/BoostingStrategy.scala | 78 ++++---- .../mllib/tree/configuration/Strategy.scala | 51 ++++-- .../mllib/tree/GradientBoostingSuite.scala | 34 ++-- 7 files changed, 462 insertions(+), 206 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java new file mode 100644 index 0000000000000..1af2067b2b929 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. 
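Before the full Java example, here is a compact sketch of the simplified configuration flow that the commit message above describes: defaultParams taking a plain algo string, numIterations replacing numEstimators, and the tree settings living on weakLearnerParams. This is a hedged illustration assembled from the examples added in this commit, pasteable into spark-shell; the values are arbitrary:

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    // Build a default strategy from an algo string, then adjust it in place.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10                // renamed from numEstimators in this patch
    boostingStrategy.numClassesForClassification = 2   // ignored for Regression
    boostingStrategy.weakLearnerParams.maxDepth = 5    // tree parameters now sit on the weak learner

Java callers reach the same knobs through the generated bean setters (setNumIterations, setMaxDepth, ...), which is what the example below does.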
+ */ +public final class JavaGradientBoostedTrees { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTrees " + + " "); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.weakLearnerParams().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + + // Train a GradientBoosting model for classification. + final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. 
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 49751a30491d0..63f02cf7b98b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -154,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -205,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. 
val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -229,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -338,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..9b6db01448be0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. 
Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala index 1a847201ce157..f729344a682e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -17,30 +17,49 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ - +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements gradient boosting for regression and binary classification problems. + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. 
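The notes in the scaladoc above describe the loop informally: compute the loss gradient to re-label instances, fit a weak hypothesis to it, shrink it by the learning rate, and add it to the ensemble. As a purely conceptual illustration of that loop for squared error, here is a toy sketch on plain doubles. It is not MLlib's implementation; the "weak learner" is deliberately just a constant (the mean of the current residuals) so the sketch stays self-contained:

    object BoostingLoopSketch {
      def main(args: Array[String]): Unit = {
        val labels = Array(1.0, 2.0, 3.0, 4.0)
        val numIterations = 10
        val learningRate = 0.5

        var prediction = Array.fill(labels.length)(0.0)
        for (_ <- 1 to numIterations) {
          // For squared error the negative gradient is simply the residual.
          val residuals = labels.zip(prediction).map { case (y, p) => y - p }
          // "Fit" a weak hypothesis to the residuals; here, a constant (their mean).
          val weakHypothesis = residuals.sum / residuals.length
          // Shrink the weak hypothesis by the learning rate and add it to the ensemble.
          prediction = prediction.map(_ + learningRate * weakHypothesis)
        }
        // Converges to the label mean, since a constant learner cannot distinguish
        // examples; a real tree learner fits the residuals per example.
        println(prediction.mkString(", "))
      }
    }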
+ * * @param boostingStrategy Parameters for the gradient boosting algorithm */ @Experimental class GradientBoosting ( private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -51,6 +70,7 @@ class GradientBoosting ( algo match { case Regression => GradientBoosting.boost(input, boostingStrategy) case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoosting.boost(remappedInput, boostingStrategy) case _ => @@ -118,120 +138,32 @@ object GradientBoosting extends Logging { } /** - * Method to train a gradient boosting binary classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] */ - def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. 
In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) } - /** * Internal method for performing regression using trees as base learners. 
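For orientation, the internal `boost` loop implemented further down reduces to roughly the following sketch (simplified pseudocode, not the literal implementation; the exact sign and scaling of the re-labelling step depend on the `Loss` in use):

    // F_0        = regression tree fit to the input labels
    // for m = 1 until numIterations:
    //   relabel each point with a pseudo-residual derived from loss.gradient(F_{m-1}, point)
    //   h_m      = regression tree fit to the relabelled data
    //   weight_m = learningRate              // not yet loss-optimized; see the comment added in the loop below
    //   F_m      = F_{m-1} + weight_m * h_m
    // result     = WeightedEnsembleModel(trees, weights, algo, Sum)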
* @param input training dataset @@ -247,15 +179,17 @@ object GradientBoosting extends Logging { timer.start("init") // Initialize gradient boosting parameters - val numEstimators = boostingStrategy.numEstimators - val baseLearners = new Array[DecisionTreeModel](numEstimators) - val baseLearnerWeights = new Array[Double](numEstimators) + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate val strategy = boostingStrategy.weakLearnerParams // Cache input - input.persist(StorageLevel.MEMORY_AND_DISK) + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } timer.stop("init") @@ -264,7 +198,7 @@ object GradientBoosting extends Logging { logDebug("##########") var data = input - // 1. Initialize tree + // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(strategy).train(data) baseLearners(0) = firstTreeModel @@ -280,7 +214,7 @@ object GradientBoosting extends Logging { point.features)) var m = 1 - while (m < numEstimators) { + while (m < numIterations) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -289,6 +223,9 @@ object GradientBoosting extends Logging { timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), @@ -305,8 +242,6 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - // 3. Output classifier new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 501d9ff9ea9b7..abbda040bd528 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,6 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** @@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. * @param loss Loss function used for minimization during gradient boosting. 
* @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. * @param numClassesForClassification Number of classes for classification. * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. * Default value is 2 (binary classification). - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are * supported. */ @Experimental case class BoostingStrategy( // Required boosting parameters - algo: Algo, - @BeanProperty var numEstimators: Int, + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var subsamplingRate: Double = 1.0, @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + - s"$learningRate.") - require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + - s"$learningRate.") - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo weakLearnerParams.numClassesForClassification = numClassesForClassification - weakLearnerParams.subsamplingRate = subsamplingRate + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. 
Provided learning rate is " + s"$learningRate.") + } } @Experimental @@ -82,28 +93,17 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = defaultWeakLearnerParams(algo) + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 algo match { - case Classification => - new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) - case Regression => - new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } } - - /** - * Returns default configuration for the weak learner (decision tree) algorithm - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @return Configuration for weak learner - */ - def defaultWeakLearnerParams(algo: Algo): Strategy = { - // Note: Regression tree used even for classification for GBT. - new Strategy(Regression, Variance, 3) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d09295c507d67..b5b1f82177edc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class Strategy ( - val algo: Algo, + @BeanProperty var algo: Algo, @BeanProperty var impurity: Impurity, @BeanProperty var maxDepth: Int, @BeanProperty var numClassesForClassification: Int = 2, @@ -85,17 +85,9 @@ class Strategy ( @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -112,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. 
+ */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -143,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala index 970fff82215e2..99a02eda60baf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -22,9 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.{Variance, Gini} +import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.LocalSparkContext @@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext class GradientBoostingSuite extends FunSuite with LocalSparkContext { test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, 1, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } test("Regression with 
continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) @@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { } } - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numEstimators, learningRate, subsamplingRate) => + case (numIterations, learningRate, subsamplingRate) => val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val categoricalFeaturesInfo = Map.empty[Int, Int] @@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(remappedInput, treeStrategy) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss, - subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, + learningRate, numClassesForClassification = 2, treeStrategy) val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numEstimators) + assert(gbt.weakHypotheses.size === numIterations) val gbtTree = gbt.weakHypotheses(0) EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) @@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext { object GradientBoostingSuite { // Combinations for estimators, learning rates and subsamplingRate - val testCombinations - = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) } From 3a9e31bd2173008fd3a71365d9f9e74a2d50b314 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 13:41:05 -0600 Subject: [PATCH 37/79] Added tests for the features in this PR. Added Date and Timestamp as aliases for RichDate and RichTimestamp when importing an SQLContext. 
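Before the diffs, a quick sketch of the usage these tests exercise (illustrative only; it assumes the catalyst DSL implicits are in scope for the Symbol forms, `sc` is an existing SparkContext, and 'eventTime is a made-up attribute name):

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext._                       // brings the Date and Timestamp aliases into scope

    Date("2014-11-05") < Date("2014-11-06")               // plain comparison of the rich values
    Timestamp("2014-11-05 12:34:56.789") < 'eventTime     // literal on the left of a Symbol
                                                          // builds a catalyst LessThan expression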
--- .../ExpressionEvaluationSuite.scala | 21 +++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 128a4860843ff..36a4adf9321a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -779,4 +779,25 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + test("recognizes literals on the left") { + assert(EQ(-1 + 'x).===(Add(-1, 'x))) + assert(EQ(0 < 'x).===(LessThan(0, 'x))) + assert(EQ(1.5 === 'x).===(EqualTo(1.5, 'x))) + assert(EQ(false !== 'x).===(Not(EqualTo(false, 'x)))) + assert(EQ("a string" >= 'x).===(GreaterThanOrEqual("a string", 'x))) + assert(EQ(RichDate("2014-11-05") > 'date).===(GreaterThan(RichDate("2014-11-05"), 'date))) + assert(EQ(RichTimestamp("2014-11-05 12:34:56.789") < 'now).===( + LessThan(RichTimestamp("2014-11-05 12:34:56.789"), 'now))) + } + + test("comparison operators for RichDate and RichTimestamp") { + assert(EQ(RichDate("2014-11-05") < RichDate("2014-11-06")).===(true)) + assert(EQ(RichDate("2014-11-05") <= RichDate("2013-11-06")).===(false)) + assert(EQ(RichTimestamp("2014-11-05 12:34:56.5432") > RichTimestamp("2014-11-05 00:00:00") + ).===(true)) + assert(EQ(RichTimestamp("2014-11-05 12:34:56") >= RichTimestamp("2014-11-06 00:00:00") + ).===(false)) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4953f8399a96b..ac139b0447249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -464,4 +464,10 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } + + /* * + * Map RichDate and RichTimestamp to their expected names in this context. + */ + val Date = org.apache.spark.sql.catalyst.expressions.RichDate + val Timestamp = org.apache.spark.sql.catalyst.expressions.RichTimestamp } From 4c42986cc070d9c5c55c7bf8a2a67585967b1082 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Wed, 5 Nov 2014 14:38:43 -0800 Subject: [PATCH 38/79] [SPARK-4242] [Core] Add SASL to external shuffle service Does three things: (1) Adds SASL to ExternalShuffleClient, (2) puts SecurityManager in BlockManager's constructor, and (3) adds unit test. Author: Aaron Davidson Closes #3108 from aarondav/sasl-client and squashes the following commits: 48b622d [Aaron Davidson] Screw it, let's just get LimitedInputStream 3543b70 [Aaron Davidson] Back out of pom change due to unknown test issue? 
b58518a [Aaron Davidson] ByteStreams.limit() not available :( cbe451a [Aaron Davidson] Address comments 2bf2908 [Aaron Davidson] [SPARK-4242] [Core] Add SASL to external shuffle service --- LICENSE | 21 +++- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 12 +- .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 4 +- network/common/pom.xml | 1 + .../buffer/FileSegmentManagedBuffer.java | 3 +- .../network/util/LimitedInputStream.java | 87 ++++++++++++++ network/shuffle/pom.xml | 1 + .../spark/network/sasl/SparkSaslClient.java | 1 - .../spark/network/sasl/SparkSaslServer.java | 9 +- .../shuffle/ExternalShuffleClient.java | 31 ++++- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 113 ++++++++++++++++++ .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- 15 files changed, 272 insertions(+), 23 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java diff --git a/LICENSE b/LICENSE index f1732fb47afc0..3c667bf45059a 100644 --- a/LICENSE +++ b/LICENSE @@ -754,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -771,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 45e9d7f243e96..e7454beddbfd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -287,7 +287,7 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. 
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 655d16c65c8b5..a5fb87b9b2c51 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -72,7 +72,8 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager) extends BlockDataManager with Logging { val diskBlockManager = new DiskBlockManager(this, conf) @@ -115,7 +116,8 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf)) + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, + securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -166,9 +168,10 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) } /** @@ -219,7 +222,6 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - val attemptsRemaining = logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1461fa69db90d..f63e772bf1e59 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) store.initialize("app-id") allStores += store store @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, 
mapOutputTracker, shuffleManager, failableTransfer) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0782876c8e3c6..9529502bc8e10 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer) + mapOutputTracker, shuffleManager, transfer, securityMgr) manager.initialize("app-id") manager } @@ -795,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/network/common/pom.xml b/network/common/pom.xml index ea887148d98ba..6144548a8f998 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -50,6 +50,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 89ed79bc63903..5fa1527ddff92 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; /** * A {@link ManagedBuffer} backed by a segment in a file. @@ -101,7 +102,7 @@ public InputStream createInputStream() throws IOException { try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return ByteStreams.limit(is, length); + return new LimitedInputStream(is, length); } catch (IOException e) { try { if (is != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000000..63ca43c046525 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. + */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index d271704d98a7a..fe5681d463499 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -51,6 +51,7 @@ com.google.guava guava + 11.0.2 provided diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 72ba737b998bc..9abad1f30a259 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -126,7 +126,6 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; rc.setText(rc.getDefaultText()); - logger.info("Realm callback"); } else if (callback instanceof RealmChoiceCallback) { // ignore (?) 
} else { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 2c0ce40c75e80..e87b17ead1e1a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,7 +34,8 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; -import com.google.common.io.BaseEncoding; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,12 +160,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8)); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index b0b19ba67bddc..3aa95d00f6b20 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,12 +17,18 @@ package org.apache.spark.network.shuffle; +import java.util.List; + +import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import org.apache.spark.network.util.JavaUtils; @@ -37,18 +43,35 @@ public class ExternalShuffleClient extends ShuffleClient { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - private final TransportClientFactory clientFactory; + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + private TransportClientFactory clientFactory; private String appId; - public ExternalShuffleClient(TransportConf conf) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); - this.clientFactory = context.createClientFactory(); + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. 
+ */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; } @Override public void init(String appId) { this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); } @Override diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index bc101f53844d5..71e017b9e4e74 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -267,7 +267,7 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { - ExternalShuffleClient client = new ExternalShuffleClient(conf); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000000..4c18fcdfbcd88 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. 
*/ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 0f27f55fec4f3..9efe15d01ed0c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -73,7 +73,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr)) + new NioBlockTransferService(conf, securityMgr), securityMgr) blockManager.initialize("app-id") tempDirectory = Files.createTempDir() From a46497eecc50f854c5c5701dc2b8a2468b76c085 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 5 Nov 2014 15:30:31 -0800 Subject: [PATCH 39/79] [SPARK-3984] [SPARK-3983] Fix incorrect scheduler delay and display task deserialization time in UI This commit fixes the scheduler delay in the UI (which previously included things that are not scheduler delay, like time to deserialize the task and serialize the result), and also adds information about time to deserialize tasks to the optional additional metrics. Time to deserialize the task can be large relative to task time for short jobs, and understanding when it is high can help developers realize that they should try to reduce closure size (e.g, by including less data in the task description). cc shivaram etrain Author: Kay Ousterhout Closes #2832 from kayousterhout/SPARK-3983 and squashes the following commits: 0c1398e [Kay Ousterhout] Fixed ordering 531575d [Kay Ousterhout] Removed executor launch time 1f13afe [Kay Ousterhout] Minor spacing fixes 335be4b [Kay Ousterhout] Made metrics hideable 5bc3cba [Kay Ousterhout] [SPARK-3984] [SPARK-3983] Improve UI task metrics. 
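The net effect of the StagePage change below is that scheduler delay no longer absorbs (de)serialization costs; using the metric names from the diff, the corrected bookkeeping is approximately:

    // schedulerDelay = totalExecutionTime
    //                    - executorRunTime
    //                    - (executorDeserializeTime + resultSerializationTime)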
--- .../org/apache/spark/executor/Executor.scala | 4 +-- .../scala/org/apache/spark/ui/ToolTips.scala | 3 ++ .../org/apache/spark/ui/jobs/StagePage.scala | 31 ++++++++++++++++++- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index abc1dd0be6237..96114571d6c77 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -161,7 +161,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() + val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -206,7 +206,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index f02904df31fcf..51dc08f668a43 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,6 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7cc03b7d333df..63ed5fc4949c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,6 +112,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Scheduler Delay +
  • + + + Task Deserialization Time + +
  • @@ -147,6 +154,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", TaskDetailsClassNames.GC_TIME), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ @@ -179,6 +187,17 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } } + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } @@ -266,6 +285,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val listings: Seq[Seq[Node]] = Seq( {serviceQuantiles}, {schedulerDelayQuantiles}, + + {deserializationQuantiles} + {gcQuantiles}, {serializationQuantiles} @@ -314,6 +336,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = info.gettingResultTime @@ -367,6 +390,10 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { class={TaskDetailsClassNames.SCHEDULER_DELAY}> {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} @@ -424,6 +451,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { (info.finishTime - info.launchTime) } } - totalExecutionTime - metrics.executorRunTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 23d672cabda07..eb371bd0ea7ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -24,6 +24,7 @@ package org.apache.spark.ui.jobs private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val GC_TIME = "gc_time" + val TASK_DESERIALIZATION_TIME = "deserialization_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } From f37817b18a479839b2e6118cc1cbd1059a94db52 Mon Sep 17 00:00:00 2001 From: industrial-sloth Date: Wed, 5 Nov 2014 15:38:48 -0800 Subject: [PATCH 40/79] SPARK-4222 [CORE] use readFully in FixedLengthBinaryRecordReader replaces the existing read() call with readFully(). 
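The motivation is the InputStream contract: read(buf, 0, len) may return after delivering only part of the requested bytes, while readFully either fills the buffer completely or fails. A minimal sketch of what readFully guarantees (illustrative helper, not code from this patch):

    import java.io.{EOFException, InputStream}

    def readFullySketch(in: InputStream, buf: Array[Byte]): Unit = {
      var off = 0
      while (off < buf.length) {
        val n = in.read(buf, off, buf.length - off)  // may legitimately return fewer bytes than requested
        if (n < 0) throw new EOFException("stream ended before the buffer was filled")
        off += n
      }
    }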
Author: industrial-sloth Closes #3093 from industrial-sloth/branch-1.2-fixedLenRecRdr and squashes the following commits: a245c8a [industrial-sloth] use readFully in FixedLengthBinaryRecordReader (cherry picked from commit 6844e7a8219ac78790a422ffd5054924e7d2bea1) Signed-off-by: Matei Zaharia --- .../org/apache/spark/input/FixedLengthBinaryRecordReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 5164a74bec4e9..36a1e5d475f46 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -115,7 +115,7 @@ private[spark] class FixedLengthBinaryRecordReader if (currentPosition < splitEnd) { // setup a buffer to store the record val buffer = recordValue.getBytes - fileInputStream.read(buffer, 0, recordLength) + fileInputStream.readFully(buffer) // update our current position currentPosition = currentPosition + recordLength // return true From 61a5cced049a8056292ba94f23fa7bd040f50685 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Nov 2014 15:42:05 -0800 Subject: [PATCH 41/79] [SPARK-3797] Run external shuffle service in Yarn NM This creates a new module `network/yarn` that depends on `network/shuffle` recently created in #3001. This PR introduces a custom Yarn auxiliary service that runs the external shuffle service. As of the changes here this shuffle service is required for using dynamic allocation with Spark. This is still WIP mainly because it doesn't handle security yet. I have tested this on a stable Yarn cluster. Author: Andrew Or Closes #3082 from andrewor14/yarn-shuffle-service and squashes the following commits: ef3ddae [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 0ee67a2 [Andrew Or] Minor wording suggestions 1c66046 [Andrew Or] Remove unused provided dependencies 0eb6233 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 6489db5 [Andrew Or] Try catch at the right places 7b71d8f [Andrew Or] Add detailed java docs + reword a few comments d1124e4 [Andrew Or] Add security to shuffle service (INCOMPLETE) 5f8a96f [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 9b6e058 [Andrew Or] Address various feedback f48b20c [Andrew Or] Fix tests again f39daa6 [Andrew Or] Do not make network-yarn an assembly module 761f58a [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-shuffle-service 15a5b37 [Andrew Or] Fix build for Hadoop 1.x baff916 [Andrew Or] Fix tests 5bf9b7e [Andrew Or] Address a few minor comments 5b419b8 [Andrew Or] Add missing license header 804e7ff [Andrew Or] Include the Yarn shuffle service jar in the distribution cd076a4 [Andrew Or] Require external shuffle service for dynamic allocation ea764e0 [Andrew Or] Connect to Yarn shuffle service only if it's enabled 1bf5109 [Andrew Or] Use the shuffle service port specified through hadoop config b4b1f0c [Andrew Or] 4 tabs -> 2 tabs 43dcb96 [Andrew Or] First cut integration of shuffle service with Yarn aux service b54a0c4 [Andrew Or] Initial skeleton for Yarn shuffle service --- .../spark/ExecutorAllocationManager.scala | 37 +++- .../apache/spark/storage/BlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 16 ++ make-distribution.sh | 3 + 
.../network/sasl/ShuffleSecretManager.java | 117 ++++++++++++ network/yarn/pom.xml | 58 ++++++ .../network/yarn/YarnShuffleService.java | 176 ++++++++++++++++++ .../yarn/util/HadoopConfigProvider.java | 42 +++++ pom.xml | 2 + project/SparkBuild.scala | 8 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ .../spark/deploy/yarn/ExecutorRunnable.scala | 16 ++ 12 files changed, 483 insertions(+), 16 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java create mode 100644 network/yarn/pom.xml create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java create mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c11f1db0064fd..ef93009a074e7 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Lower and upper bounds on the number of executors. These are required. private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) - verifyBounds() // How long there must be backlogged tasks for before an addition is triggered private val schedulerBacklogTimeout = conf.getLong( @@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) // How long an executor must be idle for before it is removed - private val removeThresholdSeconds = conf.getLong( + private val executorIdleTimeout = conf.getLong( "spark.dynamicAllocation.executorIdleTimeout", 600) + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + validateSettings() + // Number of executors to add in the next round private var numExecutorsToAdd = 1 @@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Polling loop interval (ms) private val intervalMillis: Long = 100 - // Whether we are testing this class. This should only be used internally. - private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) - // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock /** - * Verify that the lower and upper bounds on the number of executors are valid. + * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
*/ - private def verifyBounds(): Unit = { + private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") } @@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } } /** @@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging val removeRequestAcknowledged = testing || sc.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging private def onExecutorIdle(executorId: String): Unit = synchronized { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") - removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a5fb87b9b2c51..e48d7772d6ee9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -40,7 +40,6 @@ import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -97,7 +96,12 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) - private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + + // Port used by the external shuffle service. 
In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. + private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled && conf.getBoolean("spark.shuffle.consolidateFiles", false) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6ab94af9f3739..7caf6bcf94ef3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -45,6 +45,7 @@ import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -1780,6 +1781,21 @@ private[spark] object Utils extends Logging { val manifest = new JarManifest(manifestUrl.openStream()) manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) }.getOrElse("Unknown") + + /** + * Return the value of a config either through the SparkConf or the Hadoop configuration + * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf + * if the key is not set in the Hadoop configuration. + */ + def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { + val sparkValue = conf.get(key, default) + if (SparkHadoopUtil.get.isYarnMode) { + SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) + } else { + sparkValue + } + } + } /** diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4d..fac7f7e284be4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" +cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 0000000000000..e66c4af0f1ebd --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + /** + * Convert the given string to a byte buffer. The resulting buffer can be converted back to + * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static ByteBuffer stringToBytes(String s) { + return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be converted back to + * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external + * shuffle service represents shuffle secrets as bytes buffers instead of strings. + */ + public static String bytesToString(ByteBuffer b) { + return new String(b.array(), UTF8_CHARSET); + } + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. + */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. 
+ * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 0000000000000..e60d8c1f7876c --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project Yarn Shuffle Service Code + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 0000000000000..bb0b8f7e6cba6 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. 
+ * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. + */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. 
" + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 0000000000000..884861752e80d --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. 
*/ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index eb613531b8a5f..88ef67c515b3a 100644 --- a/pom.xml +++ b/pom.xml @@ -1229,6 +1229,7 @@ yarn-alpha yarn + network/yarn @@ -1236,6 +1237,7 @@ yarn yarn + network/yarn diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 33618f5401768..657e4b4432775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,9 +38,9 @@ object BuildCommons { "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) @@ -143,7 +143,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7ee4b5c842df1..5f47c79cabaee 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -90,6 +91,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git 
a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b5a92d87d722..18f48b4b6caf6 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.network.sasl.ShuffleSecretManager class ExecutorRunnable( @@ -89,6 +90,21 @@ class ExecutorRunnable( ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // If external shuffle service is enabled, register with the Yarn shuffle service already + // started on the NodeManager and, if authentication is enabled, provide it with our secret + // key for fetching shuffle files later + if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + val secretString = securityMgr.getSecretKey() + val secretBytes = + if (secretString != null) { + ShuffleSecretManager.stringToBytes(secretString) + } else { + // Authentication is not enabled, so just provide dummy metadata + ByteBuffer.allocate(0) + } + ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + } + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) } From 868cd4c3ca11e6ecc4425b972d9a20c360b52425 Mon Sep 17 00:00:00 2001 From: "jay@apache.org" Date: Wed, 5 Nov 2014 15:45:34 -0800 Subject: [PATCH 42/79] SPARK-4040. Update documentation to exemplify use of local (n) value, fo... This is a minor docs update which helps to clarify the way local[n] is used for streaming apps. Author: jay@apache.org Closes #2964 from jayunit100/SPARK-4040 and squashes the following commits: 35b5a5e [jay@apache.org] SPARK-4040: Update documentation to exemplify use of local (n) value. --- docs/configuration.md | 10 ++++++++-- docs/streaming-programming-guide.md | 14 +++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 685101ea5c9c9..0f9eb81f6e993 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. 
For instance, if you'd like to run the same application with different masters or different diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8bbba88b31978..44a1f3ad7560b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. + val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} From f7ac8c2b1de96151231617846b7468d23379c74a Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Wed, 5 Nov 2014 15:49:42 -0800 Subject: [PATCH 43/79] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly i... 
...n mesos cluster mode - change master newer Author: Jongyoul Lee Closes #3034 from jongyoul/SPARK-3223 and squashes the following commits: 42b2ed3 [Jongyoul Lee] SPARK-3223 runAsSparkUser cannot change HDFS write permission properly in mesos cluster mode - change master newer --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d8c0e2f66df01..e4b859846035c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8e2faff90f9b2..7d097a3a7aaa3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() From cb0eae3b78d7f6f56c0b9521ee48564a4967d3de Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Wed, 5 Nov 2014 16:02:44 -0800 Subject: [PATCH 44/79] [SPARK-4158] Fix for missing resources. Mesos offers may not contain all resources, and Spark needs to check to ensure they are present and sufficient. Spark may throw an erroneous exception when resources aren't present. Author: Brenden Matthews Closes #3024 from brndnmtthws/fix-mesos-resource-misuse and squashes the following commits: e5f9580 [Brenden Matthews] [SPARK-4158] Fix for missing resources. 
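For context, a minimal sketch of the behavior this fix moves to (not part of the patch; the signature is simplified to a Scala Seq, whereas the real method works on the Java list carried by the Mesos offer): when a named resource is absent from an offer, report 0 so the offer is simply judged insufficient, rather than throwing an exception.

import org.apache.mesos.Protos.Resource

// Illustrative sketch of a defensive scalar-resource lookup over a Mesos offer.
def getResource(res: Seq[Resource], name: String): Double = {
  // A resource (e.g. "cpus" or "mem") may simply be absent from this offer.
  val found = res.find(_.getName == name).map(_.getScalar.getValue)
  // Report 0 for a missing resource so the offer fails the normal sufficiency
  // check, instead of aborting with IllegalArgumentException.
  found.getOrElse(0.0)
}

Returning 0 lets the scheduler's existing "is this offer large enough" logic reject the offer, which matches Mesos semantics for resources that are not present.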
--- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 3 +-- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e4b859846035c..5289661eb896b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 7d097a3a7aaa3..c5f3493477bc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ From c315d1316cb2372e90ae3a12f72d5b3304435a6b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 19:51:18 -0800 Subject: [PATCH 45/79] [SPARK-4254] [mllib] MovieLensALS bug fix Changed code so it does not try to serialize Params. CC: mengxr debasish83 srowen Author: Joseph K. Bradley Closes #3116 from jkbradley/als-bugfix and squashes the following commits: e575bd8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into als-bugfix 9401b16 [Joseph K. 
Bradley] changed implicitPrefs so it is not serialized to fix MovieLensALS example bug --- .../scala/org/apache/spark/examples/mllib/MovieLensALS.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 8796c28db8a66..91a0a860d6c71 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -106,9 +106,11 @@ object MovieLensALS { Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see From 3d2b5bc5bb979d8b0b71e06bc0f4548376fdbb98 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 5 Nov 2014 19:56:16 -0800 Subject: [PATCH 46/79] [SPARK-4262][SQL] add .schemaRDD to JavaSchemaRDD marmbrus Author: Xiangrui Meng Closes #3125 from mengxr/SPARK-4262 and squashes the following commits: 307695e [Xiangrui Meng] add .schemaRDD to JavaSchemaRDD --- .../scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 1e0ccb368a276..78e8d908fe0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -47,6 +47,9 @@ class JavaSchemaRDD( private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + /** Returns the underlying Scala SchemaRDD. */ + val schemaRDD: SchemaRDD = baseSchemaRDD + override val classTag = scala.reflect.classTag[Row] override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) From db45f5ad0368760dbeaa618a04f66ae9b2bed656 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 5 Nov 2014 20:45:35 -0800 Subject: [PATCH 47/79] [SPARK-4137] [EC2] Don't change working dir on user This issue was uncovered after [this discussion](https://issues.apache.org/jira/browse/SPARK-3398?focusedCommentId=14187471&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14187471). Don't change the working directory on the user. This breaks relative paths the user may pass in, e.g., for the SSH identity file. ``` ./ec2/spark-ec2 -i ../my.pem ``` This patch will preserve the user's current working directory and allow calls like the one above to work. Author: Nicholas Chammas Closes #2988 from nchammas/spark-ec2-cwd and squashes the following commits: f3850b5 [Nicholas Chammas] pep8 fix fbc20c7 [Nicholas Chammas] revert to old commenting style 752f958 [Nicholas Chammas] specify deploy.generic path absolutely bcdf6a5 [Nicholas Chammas] fix typo 77871a2 [Nicholas Chammas] add clarifying comment ce071fc [Nicholas Chammas] don't change working dir --- ec2/spark-ec2 | 8 ++++++-- ec2/spark_ec2.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e51..4aa908242eeaa 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. 
# -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script. +SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 50f88f735650e..a5396c2375915 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -40,6 +40,7 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information @@ -593,7 +594,14 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ) print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -730,6 +738,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name From 5f27ae16d5b016fae4afeb0f2ad779fd3130b390 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 6 Nov 2014 00:03:03 -0800 Subject: [PATCH 48/79] [SPARK-4255] Fix incorrect table striping This commit stripes table rows after hiding some rows, to ensure that rows are correct striped to alternate white and grey even when rows are hidden by default. Author: Kay Ousterhout Closes #3117 from kayousterhout/striping and squashes the following commits: be6e10a [Kay Ousterhout] [SPARK-4255] Fix incorrect table striping --- .../org/apache/spark/ui/static/additional-metrics.js | 2 ++ core/src/main/resources/org/apache/spark/ui/static/table.js | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index c5936b5038ac9..badd85ed48c82 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -39,6 +39,8 @@ $(function() { var column = "table ." + $(this).attr("name"); $(column).hide(); }); + // Stripe table rows after rows have been hidden to ensure correct striping. + stripeTables(); $("input:checkbox").click(function() { var column = "table ." + $(this).attr("name"); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 32187ba6e8df0..6bb03015abb51 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -28,8 +28,3 @@ function stripeTables() { }); }); } - -/* Stripe all tables after pages finish loading. 
*/ -$(function() { - stripeTables(); -}); From b41a39e24038876359aeb7ce2bbbb4de2234e5f3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Nov 2014 00:22:19 -0800 Subject: [PATCH 49/79] [SPARK-4186] add binaryFiles and binaryRecords in Python add binaryFiles() and binaryRecords() in Python ``` binaryFiles(self, path, minPartitions=None): :: Developer API :: Read a directory of binary files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI as a byte array. Each file is read as a single record and returned in a key-value pair, where the key is the path of each file, the value is the content of each file. Note: Small files are preferred, large file is also allowable, but may cause bad performance. binaryRecords(self, path, recordLength): Load data from a flat binary file, assuming each record is a set of numbers with the specified numerical format (see ByteBuffer), and the number of bytes per record is constant. :param path: Directory to the input data files :param recordLength: The length at which to split the records ``` Author: Davies Liu Closes #3078 from davies/binary and squashes the following commits: cd0bdbd [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 3aa349b [Davies Liu] add experimental notes 24e84b6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 5ceaa8a [Davies Liu] Merge branch 'master' of github.com:apache/spark into binary 1900085 [Davies Liu] bugfix bb22442 [Davies Liu] add binaryFiles and binaryRecords in Python --- .../scala/org/apache/spark/SparkContext.scala | 4 ++ .../spark/api/java/JavaSparkContext.scala | 12 ++--- .../apache/spark/api/python/PythonRDD.scala | 45 ++++++++++++------- python/pyspark/context.py | 32 ++++++++++++- python/pyspark/tests.py | 19 ++++++++ 5 files changed, 90 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3cdaa6a9cc8a8..03ea672c813d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -560,6 +560,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** + * :: Experimental :: + * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -602,6 +604,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { } /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. 
* * @param path Directory to the input data files diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e3aeba7e6c39d..5c6e8d32c5c8a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,11 +21,6 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import java.io.DataInputStream - -import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} - import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,6 +28,7 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -286,6 +282,8 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** + * :: Experimental :: + * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -312,15 +310,19 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. */ + @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** + * :: Experimental :: + * * Load data from a flat binary file, assuming the length of each record is constant. 
* * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ + @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index e94ccdcd47bb7..45beb8fc8c925 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,6 +21,8 @@ import java.io._ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import org.apache.spark.input.PortableDataStream + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials @@ -395,22 +397,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a0e4821728c8b..faa5952258aef 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -388,6 +388,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of 
binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. + + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7e61b017efa75..9f625c5c6ca48 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1110,6 +1110,25 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) + class OutputFormatTests(ReusedPySparkTestCase): From 6e87d72e762da793ca46071f9415b1917ade555a Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 06:22:49 -0600 Subject: [PATCH 50/79] Removed accidentlay extraneous import from Row.scala. --- .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 99b9e6efbab90..c849f60bafe20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType import java.sql.{Date, Timestamp} -import java.math.BigDecimal object Row { /** From a5205b55108a081bb2425ee7926e134b19b8c531 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 06:25:47 -0600 Subject: [PATCH 51/79] ... and removed another extraneous import. 
--- .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index c849f60bafe20..fab0cd31407f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType -import java.sql.{Date, Timestamp} + object Row { /** From 76a18dc7df4e4c66a683c52706a1710281d34fb7 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 06:27:36 -0600 Subject: [PATCH 52/79] Tiny style issue. --- .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index fab0cd31407f1..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType - object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: From b6a43748c8478074bcc92d02c3a5ff9b6ee6c914 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 07:04:53 -0600 Subject: [PATCH 53/79] Cleaning up comments. --- .../catalyst/expressions/ExpressionEvaluationSuite.scala | 6 +++--- .../org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 6 +++--- .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/SQLConfSuite.scala | 6 +++--- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 6 +++--- .../spark/sql/columnar/PartitionBatchPruningSuite.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 6 +++--- .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 6 +++--- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 6 +++--- .../org/apache/spark/sql/parquet/HiveParquetSuite.scala | 6 +++--- 13 files changed, 39 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 36a4adf9321a1..472e97cad79b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 82887dc9d4604..3586e6557aa14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 9eca55707ae52..c5be93044d213 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index c6698d41f1c01..6532f0b4deb0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.test._ import TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0260eea467d1c..3d1f1801778c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 4a09ed517d6e1..df22ec9b1bcd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.test._ import TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 95582eec06975..5d5c2b0b5168e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.test._ import TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 45086f88a27ff..a6f31e0e15302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0ed380c63527a..2c43c110e32b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 25031910c30de..0f53e9736fb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f3323952b75d5..3565b377ef602 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b5e70e0f30599..9da56de200b11 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index be1fbca69708f..acefbe66dbd91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.util.Utils import org.apache.spark.sql.hive.test.TestHive._ /* - * Note: the DSL conversions collide with the FunSuite === operator! 
- * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(convertToEqualizer(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) * (This file already imports convertToEqualizer) */ From 0dc0ff004ef78237bca3465fb718837e99e4064c Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 10:02:41 -0600 Subject: [PATCH 54/79] One last comment clarification. --- .../scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index acefbe66dbd91..0069726426f96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.test.TestHive._ /* * Note: the DSL conversions collide with the scalatest === operator! * We can apply the scalatest conversion explicitly: - * assert(X === Y) --> assert(EQ(X).===(Y)) + * assert(X === Y) --> assert(convertToEqualizer(X).===(Y)) * (This file already imports convertToEqualizer) */ From 23eaf0e12ff221dcca40a79e61b6cc5e7c846cb5 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 10:45:46 -0800 Subject: [PATCH 55/79] [SPARK-4264] Completion iterator should only invoke callback once Author: Aaron Davidson Closes #3128 from aarondav/compiter and squashes the following commits: 698e4be [Aaron Davidson] [SPARK-4264] Completion iterator should only invoke callback once --- .../spark/util/CompletionIterator.scala | 5 +- .../spark/util/CompletionIteratorSuite.scala | 47 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index b6a099825f01b..390310243ee0a 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,10 +25,13 @@ private[spark] // scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { // scalastyle:on + + private[this] var completed = false def next() = sub.next() def hasNext = { val r = sub.hasNext - if (!r) { + if (!r && !completed) { + completed = true completion() } r diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala new file mode 100644 index 0000000000000..3755d43e25ea8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.scalatest.FunSuite + +class CompletionIteratorSuite extends FunSuite { + test("basic test") { + var numTimesCompleted = 0 + val iter = List(1, 2, 3).iterator + val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 }) + + assert(completionIter.hasNext) + assert(completionIter.next() === 1) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 2) + assert(numTimesCompleted === 0) + + assert(completionIter.hasNext) + assert(completionIter.next() === 3) + assert(numTimesCompleted === 0) + + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + + // SPARK-4264: Calling hasNext should not trigger the completion callback again. + assert(!completionIter.hasNext) + assert(numTimesCompleted === 1) + } +} From d15c6e9dc2860bbe56e31ddf71218ccc6d5c841d Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Thu, 6 Nov 2014 10:46:45 -0800 Subject: [PATCH 56/79] [SPARK-4249][GraphX]fix a problem of EdgePartitionBuilder in Graphx at first srcIds is not initialized and are all 0. so we use edgeArray(0).srcId to currSrcId Author: lianhuiwang Closes #3138 from lianhuiwang/SPARK-4249 and squashes the following commits: 3f4e503 [lianhuiwang] fix a problem of EdgePartitionBuilder in Graphx --- .../org/apache/spark/graphx/impl/EdgePartitionBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..2b6137be25547 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { srcIds(i) = edgeArray(i).srcId From 0d2e389db819ee5b48da57fc7799f80fd60eaa92 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 13:22:04 -0600 Subject: [PATCH 57/79] Adding a test which appeared after the PR and uses assert(X === Y). 
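For illustration only (not part of this patch): the comment blocks cleaned up above describe why these suites alias ScalaTest's implicit as EQ. Once the catalyst DSL conversions are imported, === on plain values can resolve to the DSL's expression-building === instead of ScalaTest's, so the conversion is applied explicitly. A minimal standalone sketch of that pattern, assuming only ScalaTest on the classpath:

    import org.scalatest.FunSuite
    // Alias ScalaTest's implicit conversion so it can be applied by hand where
    // another === (such as the catalyst DSL's) would otherwise shadow it.
    import org.scalatest.Assertions.{convertToEqualizer => EQ}

    class EqualizerWorkaroundSketch extends FunSuite {
      test("explicit equalizer still compares values") {
        val x = 1 + 1
        // Instead of assert(x === 2), write the conversion explicitly:
        assert(EQ(x).===(2))
      }
    }
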
--- .../spark/sql/UserDefinedTypeSuite.scala | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala new file mode 100644 index 0000000000000..3c1505fd63e43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.beans.{BeanInfo, BeanProperty} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext._ + +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } +} + +@BeanInfo +private[sql] case class MyLabeledPoint( + @BeanProperty label: Double, + @BeanProperty features: MyDenseVector) + +private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case features: MyDenseVector => + features.data.toSeq + } + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: Seq[_] => + new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + } + } + + override def userClass = classOf[MyDenseVector] +} + +class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + + + test("register user type: MyDenseVector for MyLabeledPoint") { + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labelsArrays: Array[Double] = labels.collect() + assert(EQ(labelsArrays.size).===(2)) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() + 
assert(EQ(featuresArrays.size).===(2)) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } +} From 470881b24a503c9edcaed159c29bafa446ab0e9a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 15:31:07 -0800 Subject: [PATCH 58/79] [HOT FIX] Make distribution fails This was added by me in https://github.com/apache/spark/commit/61a5cced049a8056292ba94f23fa7bd040f50685. The real fix will be added in [SPARK-4281](https://issues.apache.org/jira/browse/SPARK-4281). Author: Andrew Or Closes #3145 from andrewor14/fix-make-distribution and squashes the following commits: c78be61 [Andrew Or] Hot fix make distribution --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index fac7f7e284be4..0bc839e1dbe4d 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,9 +181,6 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-yarn*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-shuffle*.jar "$DISTDIR/lib/" -cp "$FWDIR"/network/yarn/target/scala*/spark-network-common*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" From 96136f222abd4f3abd10cb78a4ebecdb21f3bde7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Nov 2014 17:18:49 -0800 Subject: [PATCH 59/79] [SPARK-3797] Minor addendum to Yarn shuffle service I did not realize there was a `network.util.JavaUtils` when I wrote this code. This PR moves the `ByteBuffer` string conversion to the appropriate place. I tested the changes on a stable yarn cluster. 
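For illustration only (not part of the patch): the essential contract of the move below is that JavaUtils.stringToBytes and JavaUtils.bytesToString remain inverses, since ExecutorRunnable encodes the SASL secret that the Yarn shuffle service later decodes. A minimal sketch of that round trip; the two JavaUtils methods are exactly those added in the diff, while the object name and secret value are placeholders.

    import java.nio.ByteBuffer
    import org.apache.spark.network.util.JavaUtils

    object SecretRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val secret = "not-a-real-secret"  // placeholder value
        val buf: ByteBuffer = JavaUtils.stringToBytes(secret)
        val decoded: String = JavaUtils.bytesToString(buf)
        // Encoder and decoder must agree, or authentication with the
        // external shuffle service fails.
        assert(decoded == secret)
      }
    }
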
Author: Andrew Or Closes #3144 from andrewor14/yarn-shuffle-util and squashes the following commits: b6c08bf [Andrew Or] Remove unused import 94e205c [Andrew Or] Use netty Unpooled 85202a5 [Andrew Or] Use guava Charsets 057135b [Andrew Or] Reword comment adf186d [Andrew Or] Move byte buffer String conversion logic to JavaUtils --- .../apache/spark/network/util/JavaUtils.java | 20 ++++++++++++++++ .../network/sasl/ShuffleSecretManager.java | 24 ++----------------- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 5 ++-- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 40b71b0c87a47..2856d1c8c9337 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import java.nio.ByteBuffer; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; @@ -25,6 +27,8 @@ import java.io.ObjectOutputStream; import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,4 +77,20 @@ public static int nonNegativeHash(Object obj) { int hash = obj.hashCode(); return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; } + + /** + * Convert the given string to a byte buffer. The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index e66c4af0f1ebd..351c7930a900f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -19,13 +19,13 @@ import java.lang.Override; import java.nio.ByteBuffer; -import java.nio.charset.Charset; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; /** * A class that manages shuffle secret used by the external shuffle service. @@ -34,30 +34,10 @@ public class ShuffleSecretManager implements SecretKeyHolder { private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); private final ConcurrentHashMap shuffleSecretMap; - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - // Spark user used for authenticating SASL connections // Note that this must match the value in org.apache.spark.SecurityManager private static final String SPARK_SASL_USER = "sparkSaslUser"; - /** - * Convert the given string to a byte buffer. 
The resulting buffer can be converted back to - * the same string through {@link #bytesToString(ByteBuffer)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static ByteBuffer stringToBytes(String s) { - return ByteBuffer.wrap(s.getBytes(UTF8_CHARSET)); - } - - /** - * Convert the given byte buffer to a string. The resulting string can be converted back to - * the same byte buffer through {@link #stringToBytes(String)}. This is used if the external - * shuffle service represents shuffle secrets as bytes buffers instead of strings. - */ - public static String bytesToString(ByteBuffer b) { - return new String(b.array(), UTF8_CHARSET); - } - public ShuffleSecretManager() { shuffleSecretMap = new ConcurrentHashMap(); } @@ -80,7 +60,7 @@ public void registerApp(String appId, String shuffleSecret) { * Register an application with its secret specified as a byte buffer. */ public void registerApp(String appId, ByteBuffer shuffleSecret) { - registerApp(appId, bytesToString(shuffleSecret)); + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 5f47c79cabaee..7023a1170654f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils @deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( @@ -98,7 +98,8 @@ class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 18f48b4b6caf6..fdd3c2300fa78 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -36,7 +36,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.network.sasl.ShuffleSecretManager +import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -97,7 +97,8 @@ class ExecutorRunnable( val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { - ShuffleSecretManager.stringToBytes(secretString) + // This conversion must match how the YarnShuffleService decodes our secret + JavaUtils.stringToBytes(secretString) } else { // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) From 6e9ef10fd7446a11f37446c961916ba2a8e02cb8 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 17:20:46 -0800 Subject: [PATCH 60/79] 
[SPARK-4277] Support external shuffle service on Standalone Worker Author: Aaron Davidson Closes #3142 from aarondav/worker and squashes the following commits: 3780bd7 [Aaron Davidson] Address comments 2dcdfc1 [Aaron Davidson] Add private[worker] 47f49d3 [Aaron Davidson] NettyBlockTransferService shouldn't care about app ids (it's only b/t executors) 258417c [Aaron Davidson] [SPARK-4277] Support external shuffle service on executor --- .../org/apache/spark/SecurityManager.scala | 14 +--- .../StandaloneWorkerShuffleService.scala | 66 +++++++++++++++++++ .../apache/spark/deploy/worker/Worker.scala | 8 ++- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../NettyBlockTransferSecuritySuite.scala | 12 ---- .../spark/network/sasl/SaslMessage.java | 3 +- 6 files changed, 79 insertions(+), 26 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dee935ffad51f..dbff9d12b5ad7 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with */ def getSecretKey(): String = secretKey - override def getSaslUser(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSaslUser() - } - - override def getSecretKey(appId: String): String = { - val myAppId = sparkConf.getAppId - require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}") - getSecretKey() - } + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 0000000000000..88118e2837741 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val blockHandler = new ExternalShuffleBlockHandler() + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f1f66d0903f1c..ca262de832e25 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -111,6 +111,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. 
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -154,6 +157,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -419,6 +423,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 1e579187e4193..6b1f57a069431 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index bed0ed9d713dd..9162ec9801663 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -89,18 +89,6 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh } } - test("security mismatch app ids") { - val conf0 = new SparkConf() - .set("spark.authenticate", "true") - .set("spark.authenticate.secret", "good") - .set("spark.app.id", "app-id") - val conf1 = conf0.clone.set("spark.app.id", "other-id") - testConnection(conf0, conf1) match { - case Success(_) => fail("Should have failed") - case Failure(t) => t.getMessage should include ("SASL appId app-id did not match") - } - } - /** * Creates two servers with different configurations and sees if they can talk. 
* Returns Success() if they can transfer a block, and Failure() if the block transfer was failed diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 5b77e18c26bf4..599cc6428c90e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -58,7 +58,8 @@ public void encode(ByteBuf buf) { public static SaslMessage decode(ByteBuf buf) { if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else"); + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); } int idLength = buf.readInt(); From f165b2bbf5d4acf34d826fa55b900f5bbc295654 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 18:39:14 -0800 Subject: [PATCH 61/79] [SPARK-4188] [Core] Perform network-level retry of shuffle file fetches This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException. This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load. TODO: - [x] unit tests - [x] put in ExternalShuffleClient too Author: Aaron Davidson Closes #3101 from aarondav/retry and squashes the following commits: 72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy c7fd107 [Aaron Davidson] Fix unit tests e80e4c2 [Aaron Davidson] Address initial comments 6f594cd [Aaron Davidson] Fix unit test 05ff43c [Aaron Davidson] Add to external shuffle client and add unit test 66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches --- .../netty/NettyBlockTransferService.scala | 21 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportClientFactory.java | 13 +- .../client/TransportResponseHandler.java | 3 +- .../network/protocol/MessageEncoder.java | 2 +- .../spark/network/server/TransportServer.java | 8 +- .../apache/spark/network/util/NettyUtils.java | 14 +- .../spark/network/util/TransportConf.java | 17 + .../network/TransportClientFactorySuite.java | 7 +- .../shuffle/ExternalShuffleClient.java | 31 +- .../shuffle/OneForOneBlockFetcher.java | 9 +- .../network/shuffle/RetryingBlockFetcher.java | 234 +++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 4 +- .../ExternalShuffleIntegrationSuite.java | 18 +- .../shuffle/ExternalShuffleSecuritySuite.java | 6 +- .../shuffle/RetryingBlockFetcherSuite.java | 310 ++++++++++++++++++ 16 files changed, 668 insertions(+), 45 deletions(-) create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 0d1fc81d2a16f..b937ea825f49e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal import 
org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage listener: BlockFetchingListener): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a08cee02dd576..4e944114e8176 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -18,7 +18,9 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import com.google.common.base.Objects; @@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception { serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); - callback.onFailure(new RuntimeException(errorMsg, future.cause())); channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } } } }); @@ -175,6 +185,8 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw 
Throwables.propagate(e.getCause()); } catch (Exception e) { throw Throwables.propagate(e); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 1723fed307257..397d3a8455c86 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -18,12 +18,12 @@ package org.apache.spark.network.client; import java.io.Closeable; +import java.io.IOException; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import com.google.common.base.Preconditions; @@ -44,7 +44,6 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -93,15 +92,17 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) { + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); if (cachedClient != null) { if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); return cachedClient; } else { + logger.info("Found inactive connection to {}, closing it.", address); connectionPool.remove(address, cachedClient); // Remove inactive clients. 
} } @@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) { long preConnect = System.currentTimeMillis(); ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new RuntimeException( + throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { - throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } TransportClient client = clientRef.get(); @@ -198,7 +199,7 @@ public void close() { */ private PooledByteBufAllocator createPooledByteBufAllocator() { return new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), getPrivateStaticField("DEFAULT_PAGE_SIZE"), diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d8965590b34da..2044afb0d85db 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.client; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +95,7 @@ public void channelUnregistered() { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 4cb8becc3ed22..91d1e8a538a77 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { // All messages have the frame length, message type, and message itself. 
int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; - ByteBuf header = ctx.alloc().buffer(headerLength); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); in.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 70da48ca8ee79..579676c2c3564 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -71,11 +72,14 @@ private void init(int portToBind) { NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); if (conf.backLog() > 0) { bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index b1872341198e0..2a7664fe89388 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -37,13 +37,17 @@ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. */ public class NettyUtils { - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - - ThreadFactory threadFactory = new ThreadFactoryBuilder() + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") + .setNameFormat(threadPoolPrefix + "-%d") .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. 
*/ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); switch (mode) { case NIO: diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 823790dd3c66f..787a8f0031af1 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) { /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + /** If true, we will prefer allocating off-heap byte buffers within Netty. */ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + /** Connect timeout in secs. Default 120 secs. */ public int connectionTimeoutMs() { return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; @@ -58,4 +63,16 @@ public int connectionTimeoutMs() { /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 5a10fdb3842ef..822bef1d81b2a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.util.concurrent.TimeoutException; import org.junit.After; @@ -57,7 +58,7 @@ public void tearDown() { } @Test - public void createAndReuseBlockClients() throws TimeoutException { + public void createAndReuseBlockClients() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); @@ -70,7 +71,7 @@ public void createAndReuseBlockClients() throws TimeoutException { } @Test - public void neverReturnInactiveClients() throws Exception { + public void neverReturnInactiveClients() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -88,7 +89,7 @@ public void neverReturnInactiveClients() throws Exception { } @Test - public void closeBlockClientsWithFactory() throws TimeoutException { + public void closeBlockClientsWithFactory() throws IOException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = 
factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 3aa95d00f6b20..27884b82c8cb9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; import java.util.List; import com.google.common.collect.Lists; @@ -76,17 +77,33 @@ public void init(String appId) { @Override public void fetchBlocks( - String host, - int port, - String execId, + final String host, + final int port, + final String execId, String[] blockIds, BlockFetchingListener listener) { assert appId != null : "Called before init()"; logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } } catch (Exception e) { logger.error("Exception while beginning fetchBlocks", e); for (String blockId : blockIds) { @@ -108,7 +125,7 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) { + ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); byte[] registerExecutorMessage = diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 39b6f30f92baf..9e77a1f68c4b0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -51,9 +51,6 @@ public OneForOneBlockFetcher( TransportClient client, String[] blockIds, BlockFetchingListener listener) { - if (blockIds.length == 0) { - throw new IllegalArgumentException("Zero-sized blockIds array"); - } this.client = client; this.blockIds = blockIds; this.listener = listener; @@ -82,6 +79,10 @@ public void onFailure(int chunkIndex, Throwable e) { * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. 
*/ public void start(Object openBlocksMessage) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { @@ -95,7 +96,7 @@ public void onSuccess(byte[] response) { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } } catch (Exception e) { - logger.error("Failed while starting block fetches", e); + logger.error("Failed while starting block fetches after success", e); failRemainingBlocks(blockIds, e); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 0000000000000..f8a1a266863bb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. 
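The createAndStart() contract described above is easier to see pulled out into a named class. A sketch under stated assumptions: the factory, host/port and app/executor ids are stand-ins supplied by whoever owns the starter, and ExternalShuffleClient.fetchBlocks above builds the equivalent starter as an anonymous class.

```java
package org.apache.spark.network.shuffle;

import java.io.IOException;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;

/** Sketch of a BlockFetchStarter for the external shuffle service; collaborators are assumed. */
public class ExternalBlockFetchStarter implements RetryingBlockFetcher.BlockFetchStarter {
  private final TransportClientFactory clientFactory;
  private final String host;
  private final int port;
  private final String appId;
  private final String execId;

  public ExternalBlockFetchStarter(
      TransportClientFactory clientFactory, String host, int port, String appId, String execId) {
    this.clientFactory = clientFactory;
    this.host = host;
    this.port = port;
    this.appId = appId;
    this.execId = execId;
  }

  @Override
  public void createAndStart(String[] blockIds, BlockFetchingListener listener)
      throws IOException {
    // Ask the factory for a client on every attempt, so that a connection broken by a
    // previous attempt is not silently reused.
    TransportClient client = clientFactory.createClient(host, port);
    new OneForOneBlockFetcher(client, blockIds, listener)
      .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
  }
}
```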
+ */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. + * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. 
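From the caller's side, the fetcher is constructed with the transport conf, a starter, the block ids and the parent listener, and start() fires the first attempt; the listener is then invoked exactly once per block id. A sketch (the conf and starter are assumed to be supplied by the caller, and the block ids are made up):

```java
package org.apache.spark.network.shuffle;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.util.TransportConf;

public class RetryingFetchExample {
  /** Fetches two made-up shuffle blocks, retrying transient IOExceptions per the conf. */
  public static void fetchWithRetries(TransportConf conf,
      RetryingBlockFetcher.BlockFetchStarter starter) {
    String[] blockIds = { "shuffle_0_0_0", "shuffle_0_0_1" };

    BlockFetchingListener listener = new BlockFetchingListener() {
      @Override
      public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
        // Called exactly once for each block that eventually succeeds.
        System.out.println("Fetched " + blockId + " (" + data.size() + " bytes)");
      }

      @Override
      public void onBlockFetchFailure(String blockId, Throwable exception) {
        // Called exactly once for each block that fails permanently (a non-IOException,
        // or an IOException once the retries are exhausted).
        System.err.println("Giving up on " + blockId + ": " + exception);
      }
    };

    new RetryingBlockFetcher(conf, starter, blockIds, listener).start();
  }
}
```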
The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. + * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. 
+ if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 84781207861ed..d25283e46ef96 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -93,7 +93,7 @@ public void afterEach() { } @Test - public void testGoodClient() { + public void testGoodClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList( new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); @@ -119,7 +119,7 @@ public void testBadClient() { } @Test - public void testNoSaslClient() { + public void testNoSaslClient() throws IOException { clientFactory = context.createClientFactory( Lists.newArrayList()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 71e017b9e4e74..06294fef19621 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -259,14 +259,20 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 4c18fcdfbcd88..848c88f743d50 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle; +import java.io.IOException; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -54,7 +56,7 @@ public void afterEach() { } @Test - public void testValid() { + public void testValid() throws IOException { 
validate("my-app-id", "secret"); } @@ -77,7 +79,7 @@ public void testBadSecret() { } /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) { + private void validate(String appId, String secretKey) throws IOException { ExternalShuffleClient client = new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); client.init(appId); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 0000000000000..0191fe529e1be --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. 
+ ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. 
+ ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. 
+ BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} From 48a19a6dba896f7d0b637f84e114b7efbb814e51 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 6 Nov 2014 19:54:32 -0800 Subject: [PATCH 62/79] [SPARK-4236] Cleanup removed applications' files in shuffle service This relies on a hook from whoever is hosting the shuffle service to invoke removeApplication() when the application is completed. Once invoked, we will clean up all the executors' shuffle directories we know about. Author: Aaron Davidson Closes #3126 from aarondav/cleanup and squashes the following commits: 33a64a9 [Aaron Davidson] Missing brace e6e428f [Aaron Davidson] Address comments 16a0d27 [Aaron Davidson] Cleanup e4df3e7 [Aaron Davidson] [SPARK-4236] Cleanup removed applications' files in shuffle service --- .../scala/org/apache/spark/util/Utils.scala | 1 + .../spark/ExternalShuffleServiceSuite.scala | 5 +- .../apache/spark/network/util/JavaUtils.java | 59 ++++++++ .../shuffle/ExternalShuffleBlockHandler.java | 10 +- .../shuffle/ExternalShuffleBlockManager.java | 118 +++++++++++++-- .../shuffle/ExternalShuffleCleanupSuite.java | 142 ++++++++++++++++++ .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/TestShuffleDataContext.java | 4 +- 8 files changed, 319 insertions(+), 22 deletions(-) create mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7caf6bcf94ef3..2cbd38d72caa1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -755,6 +755,7 @@ private[spark] object Utils extends Logging { /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. 
*/ def deleteRecursively(file: File) { if (file != null) { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 792b9cd8b6ff2..6608ed1e57b38 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -63,8 +63,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { rdd.count() rdd.count() - // Invalidate the registered executors, disallowing access to their shuffle blocks. - rpcHandler.clearRegisteredExecutors() + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" // being set. diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 2856d1c8c9337..75c4a3981a240 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -22,16 +22,22 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import com.google.common.base.Preconditions; import com.google.common.io.Closeables; import com.google.common.base.Charsets; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); @@ -93,4 +99,57 @@ public static ByteBuffer stringToBytes(String s) { public static String bytesToString(ByteBuffer b) { return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. 
+ if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index cd3fea85b19a4..75ebf8c7b0604 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -94,9 +94,11 @@ public StreamManager getStreamManager() { return streamManager; } - /** For testing, clears all executors registered with "RegisterExecutor". */ - @VisibleForTesting - public void clearRegisteredExecutors() { - blockManager.clearRegisteredExecutors(); + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 6589889fe1be7..98fcfb82aa5d1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -21,9 +21,15 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,13 +49,22 @@ public class ExternalShuffleBlockManager { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); - // Map from "appId-execId" to the executor's configuration. - private final ConcurrentHashMap executors = - new ConcurrentHashMap(); + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; - // Returns an id suitable for a single executor within a single application. - private String getAppExecId(String appId, String execId) { - return appId + "-" + execId; + // Single-threaded Java executor used to perform expensive recursive directory deletion. 
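The "hook" the commit message mentions is the applicationRemoved() entry point shown just above (the message calls it removeApplication()). Whatever process hosts the shuffle service is expected to invoke it once the application terminates; a sketch, with the handler and appId assumed to come from that host:

```java
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;

/** Hypothetical host-side hook; the handler instance is owned by the hosting process. */
public final class ShuffleServiceHostHook {
  private ShuffleServiceHostHook() {}

  public static void onApplicationEnd(ExternalShuffleBlockHandler handler, String appId) {
    // Passing true also deletes the executors' local shuffle directories, on a separate thread.
    handler.applicationRemoved(appId, true /* cleanupLocalDirs */);
  }
}
```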
+ private final Executor directoryCleaner; + + public ExternalShuffleBlockManager() { + // TODO: Give this thread a name. + this(Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(Executor directoryCleaner) { + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; } /** Registers a new Executor with all the configuration we need to find its shuffle files. */ @@ -57,7 +72,7 @@ public void registerExecutor( String appId, String execId, ExecutorShuffleInfo executorInfo) { - String fullId = getAppExecId(appId, execId); + AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); executors.put(fullId, executorInfo); } @@ -78,7 +93,7 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { int mapId = Integer.parseInt(blockIdParts[2]); int reduceId = Integer.parseInt(blockIdParts[3]); - ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); @@ -94,6 +109,56 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { } } + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + /** * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockManager. @@ -146,9 +211,36 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } - /** For testing, clears all registered executors. */ - @VisibleForTesting - void clearRegisteredExecutors() { - executors.clear(); + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 0000000000000..c8ece3bc53ac3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
+ Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 06294fef19621..3bea5b0f253c6 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -105,7 +105,7 @@ public static void afterAll() { @After public void 
afterEach() { - handler.clearRegisteredExecutors(); + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); } class FetchResult { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 442b756467442..337b5c7bdb5da 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -30,8 +30,8 @@ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. */ public class TestShuffleDataContext { - private final String[] localDirs; - private final int subDirsPerLocalDir; + public final String[] localDirs; + public final int subDirsPerLocalDir; public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { this.localDirs = new String[numLocalDirs]; From 3abdb1b24aa48f21e7eed1232c01d3933873688c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Nov 2014 21:52:12 -0800 Subject: [PATCH 63/79] [SPARK-4204][Core][WebUI] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly This PR fixed `Utils.exceptionString` to output the full exception information. However, the stack trace may become very huge, so I also updated the Web UI to collapse the error information by default (display the first line and clicking `+detail` will display the full info). Here are the screenshots: Stages: ![stages](https://cloud.githubusercontent.com/assets/1000778/4882441/66d8cc68-6356-11e4-8346-6318677d9470.png) Details for one stage: ![stage](https://cloud.githubusercontent.com/assets/1000778/4882513/1311043c-6357-11e4-8804-ca14240a9145.png) The full information in the gray text field is: ```Java org.apache.spark.shuffle.FetchFailedException: Connection reset by peer at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$.org$apache$spark$shuffle$hash$BlockStoreShuffleFetcher$$unpackBlock$1(BlockStoreShuffleFetcher.scala:67) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$$anonfun$3.apply(BlockStoreShuffleFetcher.scala:83) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:30) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:39) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327) at org.apache.spark.util.collection.ExternalAppendOnlyMap.insertAll(ExternalAppendOnlyMap.scala:129) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:160) at org.apache.spark.rdd.CoGroupedRDD$$anonfun$compute$5.apply(CoGroupedRDD.scala:159) at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:772) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:771) at org.apache.spark.rdd.CoGroupedRDD.compute(CoGroupedRDD.scala:159) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at 
org.apache.spark.rdd.MappedValuesRDD.compute(MappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.rdd.FlatMappedValuesRDD.compute(FlatMappedValuesRDD.scala:31) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:263) at org.apache.spark.rdd.RDD.iterator(RDD.scala:230) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61) at org.apache.spark.scheduler.Task.run(Task.scala:56) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:189) at java.util.concurrent.ThreadPoolExecutor$Worker.runTask(ThreadPoolExecutor.java:886) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:908) at java.lang.Thread.run(Thread.java:662) Caused by: java.io.IOException: Connection reset by peer at sun.nio.ch.FileDispatcher.read0(Native Method) at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:21) at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:198) at sun.nio.ch.IOUtil.read(IOUtil.java:166) at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:245) at io.netty.buffer.PooledUnsafeDirectByteBuf.setBytes(PooledUnsafeDirectByteBuf.java:311) at io.netty.buffer.AbstractByteBuf.writeBytes(AbstractByteBuf.java:881) at io.netty.channel.socket.nio.NioSocketChannel.doReadBytes(NioSocketChannel.java:225) at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:119) at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511) at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468) at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382) at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354) at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:116) ... 
1 more ``` /cc aarondav Author: zsxwing Closes #3073 from zsxwing/SPARK-4204 and squashes the following commits: 176d1e3 [zsxwing] Add comments to explain the stack trace difference ca509d3 [zsxwing] Add fullStackTrace to the constructor of ExceptionFailure a07057b [zsxwing] Core style fix dfb0032 [zsxwing] Backward compatibility for old history server 1e50f71 [zsxwing] Update as per review and increase the max height of the stack trace details 94f2566 [zsxwing] Change Utils.exceptionString to contain the inner exceptions and make the error information in Web UI more friendly --- .../org/apache/spark/ui/static/webui.css | 14 ++++++++ .../org/apache/spark/TaskEndReason.scala | 35 ++++++++++++++++++- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../spark/shuffle/FetchFailedException.scala | 17 +++++++-- .../hash/BlockStoreShuffleFetcher.scala | 5 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 32 +++++++++++++++-- .../org/apache/spark/ui/jobs/StageTable.scala | 28 +++++++++++++-- .../org/apache/spark/util/JsonProtocol.scala | 5 ++- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++------- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 10 +++++- 12 files changed, 148 insertions(+), 30 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a2220e761ac98..db57712c83503 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,6 +120,20 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + span.expand-additional-metrics { cursor: pointer; } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index f45b463fb6f62..af5fd8e0ac00c 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -83,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. 
+ * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. + */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 96114571d6c77..caf4d76713d49 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -263,7 +263,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 96114c0423a9e..22449517d100f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1063,7 +1063,7 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage)) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } @@ -1094,7 +1094,7 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 0c1b6f4defdb3..be184464e0ae9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -32,10 +32,21 @@ private[spark] class FetchFailedException( shuffleId: Int, mapId: Int, reduceId: Int, - message: String) - extends Exception(message) { + message: String, + cause: Throwable = null) + extends Exception(message, cause) { + + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0d5247f4176d4..e3e7434df45b0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -25,7 +25,7 @@ import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -64,8 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, - Utils.exceptionString(e)) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 63ed5fc4949c2..250bddbe2f262 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ @@ -436,13 +438,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} }} - - {errorMessage.map { e =>
    <pre>{e}</pre>
    }.getOrElse("")} - + {errorMessageCell(errorMessage)} } } + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + {errorSummary}{details} + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = { if (info.gettingResultTime > 0) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4ee7f08ab47a2..3b4866e05956d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -195,7 +197,29 @@ private[ui] class FailedStageTable( override protected def stageRow(s: StageInfo): Seq[Node] = { val basicColumns = super.stageRow(s) - val failureReason =
    <td valign="middle"><pre>{s.failureReason.getOrElse("")}</pre></td>
    - basicColumns ++ failureReason + val failureReason = s.failureReason.getOrElse("") + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + val failureReasonHtml = {failureReasonSummary}{details} + basicColumns ++ failureReasonHtml } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f7ae1f7f334de..f15d0c856663f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -287,6 +287,7 @@ private[spark] object JsonProtocol { ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ + ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) case ExecutorLostFailure(executorId) => ("Executor ID" -> executorId) @@ -637,8 +638,10 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") + val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). + map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - new ExceptionFailure(className, description, stackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2cbd38d72caa1..a14d6125484fe 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1599,19 +1599,19 @@ private[spark] object Utils extends Logging { .orNull } - /** Return a nice string representation of the exception, including the stack trace. */ + /** + * Return a nice string representation of the exception. It will call "printStackTrace" to + * recursively generate the stack trace including the exception and its causes. + */ def exceptionString(e: Throwable): String = { - if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) - } - - /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString( - className: String, - description: String, - stackTrace: Array[StackTraceElement]): String = { - val desc = if (description == null) "" else description - val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") - s"$className: $desc\n$st" + if (e == null) { + "" + } else { + // Use e.printStackTrace here because e.getStackTrace doesn't include the cause + val stringWriter = new StringWriter() + e.printStackTrace(new PrintWriter(stringWriter)) + stringWriter.toString + } } /** Return a thread dump of all threads' stacktraces. 
Used to capture dumps for the web UI */ diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 2efbae689771a..2608ad4b32e1e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -116,7 +116,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - new ExceptionFailure("Exception", "description", null, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index aec1e409db95c..39e69851e7e3c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -109,7 +109,7 @@ class JsonProtocolSuite extends FunSuite { // TaskEndReason val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) + val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -127,6 +127,13 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("ExceptionFailure backward compatibility") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) + .removeField({ _._1 == "Full Stack Trace" }) + assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("StageInfo backward compatibility") { val info = makeStageInfo(1, 2, 3, 4L, 5L) val newJson = JsonProtocol.stageInfoToJson(info) @@ -422,6 +429,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.className === r2.className) assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) + assert(r1.fullStackTrace === r2.fullStackTrace) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => From c3f9ce1f9d7d4e1affcf3240768d55bc43858abd Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Fri, 7 Nov 2014 07:36:39 -0600 Subject: [PATCH 64/79] Removing new test until I rebase the repository. --- .../spark/sql/UserDefinedTypeSuite.scala | 98 ------------------- 1 file changed, 98 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala deleted file mode 100644 index 3c1505fd63e43..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.beans.{BeanInfo, BeanProperty} - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.types.UserDefinedType -import org.apache.spark.sql.test.TestSQLContext._ - -/* - * Note: the DSL conversions collide with the scalatest === operator! - * We can apply the scalatest conversion explicitly: - * assert(X === Y) --> assert(EQ(X).===(Y)) - */ -import org.scalatest.Assertions.{convertToEqualizer => EQ} - -@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) -private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { - override def equals(other: Any): Boolean = other match { - case v: MyDenseVector => - java.util.Arrays.equals(this.data, v.data) - case _ => false - } -} - -@BeanInfo -private[sql] case class MyLabeledPoint( - @BeanProperty label: Double, - @BeanProperty features: MyDenseVector) - -private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - - override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - - override def serialize(obj: Any): Seq[Double] = { - obj match { - case features: MyDenseVector => - features.data.toSeq - } - } - - override def deserialize(datum: Any): MyDenseVector = { - datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) - } - } - - override def userClass = classOf[MyDenseVector] -} - -class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) - - - test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } - val labelsArrays: Array[Double] = labels.collect() - assert(EQ(labelsArrays.size).===(2)) - assert(labelsArrays.contains(1.0)) - assert(labelsArrays.contains(0.0)) - - val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } - val featuresArrays: Array[MyDenseVector] = features.collect() - assert(EQ(featuresArrays.size).===(2)) - assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) - } - - test("UDTs and UDFs") { - registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) - pointsRDD.registerTempTable("points") - checkAnswer( - sql("SELECT testType(features) from points"), - Seq(Row(true), Row(true))) - } -} From 908fc6a1df7a216b965ce75e129db93db6fd02b4 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Sun, 2 Nov 2014 20:57:51 -0600 Subject: [PATCH 65/79] Adding Timestamp and Date classes which support the standard comparison operators, as well as implicit conversions to support using these classes in the 
catalalyst DSL. This commit also adds a method to Row which builds a row from a schema and a list of strings. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 33 ++++++- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 35 +++++++ .../spark/sql/catalyst/expressions/Row.scala | 48 ++++++++- .../expressions/SpecificMutableRow.scala | 49 ++++++++++ .../sql/catalyst/expressions/literals.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 5 +- .../spark/sql/catalyst/types/timetypes.scala | 97 +++++++++++++++++++ .../spark/sql/columnar/ColumnType.scala | 5 + .../scala/org/apache/spark/sql/package.scala | 38 +++++++- 11 files changed, 303 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9cda373623cb5..68419834069b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -159,9 +159,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DateType.JvmType => DateType case obj: BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited + case obj: DateType.JvmType => DateType case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 31dc5a58e68e5..ed75a183b7fa8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -127,9 +127,9 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) - implicit def dateToLiteral(d: Date) = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) implicit def decimalToLiteral(d: Decimal) = Literal(d) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -148,6 +148,31 @@ package object dsl { def upper(e: Expression) = Upper(e) def lower(e: Expression) = Lower(e) + /* + * Conversions to provide the standard operators in the special case + * where a literal is being combined with a symbol. Without these an + * expression such as 0 < 'x is not recognized. 
+ */ + implicit class InitialLiteral(x: Any) { + val literal = Literal(x) + def + (other: Symbol):Expression = {literal + other} + def - (other: Symbol):Expression = {literal - other} + def * (other: Symbol):Expression = {literal * other} + def / (other: Symbol):Expression = {literal / other} + def % (other: Symbol):Expression = {literal % other} + + def && (other: Symbol):Expression = {literal && other} + def || (other: Symbol):Expression = {literal || other} + + def < (other: Symbol):Expression = {literal < other} + def <= (other: Symbol):Expression = {literal <= other} + def > (other: Symbol):Expression = {literal > other} + def >= (other: Symbol):Expression = {literal >= other} + def === (other: Symbol):Expression = {literal === other} + def <=> (other: Symbol):Expression = {literal <=> other} + def !== (other: Symbol):Expression = {literal !== other} + } + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { @@ -184,9 +209,6 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() - /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() - /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() @@ -194,6 +216,9 @@ package object dsl { def decimal(precision: Int, scale: Int) = AttributeReference(s, DecimalType(precision, scale), nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 22009666196a1..38172eb1a50fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -31,8 +31,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true case (StringType, DateType) => true + case (StringType, TimestampType) => true case (_: NumericType, DateType) => true case (BooleanType, DateType) => true case (DateType, _: NumericType) => true @@ -333,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case DateType => castToDate case decimal: DecimalType => castToDecimal(decimal) + case DateType => castToDate case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index e7e81a21fdf03..4391bcede66d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -139,6 +140,12 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -231,6 +238,13 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -317,6 +331,13 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -403,6 +424,13 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -489,6 +517,13 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def getDate(i: Int): Date = + if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) + + def getTimestamp(i: Int): Timestamp = + if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) + override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d00ec39774c35..738a5f3a53f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.catalyst.types._ +import java.sql.{Date, Timestamp} +import java.math.BigDecimal object Row { /** @@ -42,6 +44,31 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of Strings, + * converting each item to the type specified in a [[StructType]] schema. + * Only primitive types can be used. + */ + def fromStringsBySchema(strings: Seq[String], schema: StructType): Row = { + val values = for { + (field, str) <- schema.fields zip strings + item = field.dataType match { + case IntegerType => str.toInt + case LongType => str.toLong + case DoubleType => str.toDouble + case FloatType => str.toFloat + case ByteType => str.toByte + case ShortType => str.toShort + case StringType => str + case BooleanType => (str != "") + case DateType => Date.valueOf(str) + case TimestampType => Timestamp.valueOf(str) + case DecimalType() => new BigDecimal(str) + } + } yield item + new GenericRow(values.toArray) + } } /** @@ -64,6 +91,8 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getDate(i: Int): Date + def getTimestamp(i: Int): Timestamp def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = @@ -99,6 +128,8 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + def setDate(ordinal: Int, value: Date) + def setTimestamp(ordinal: Int, value: Timestamp) } /** @@ -119,6 +150,9 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + def getDate(i: Int): Date = throw new UnsupportedOperationException + def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this @@ -183,6 +217,16 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + def getDate(i: Int): Date = { + if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") + values(i).asInstanceOf[Date] + } + + def getTimestamp(i: Int): Timestamp = { + if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") + values(i).asInstanceOf[Timestamp] + } + // Custom hashCode function that matches the efficient code generated version. 
override def hashCode(): Int = { var result: Int = 37 @@ -226,6 +270,8 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } + override def setDate(ordinal: Int,value: Date): Unit = { values(ordinal) = value } + override def setTimestamp(ordinal: Int,value: Timestamp): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 570379c533e1f..34a10cf4a6945 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ +import java.sql.{Date, Timestamp} /** * A parent class for mutable container objects that are reused when the values are changed, @@ -169,6 +170,35 @@ final class MutableByte extends MutableValue { newCopy.asInstanceOf[this.type] } } +final class MutableDate extends MutableValue { + var value: Date = new Date(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Date] + } + def copy() = { + val newCopy = new MutableDate + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableTimestamp extends MutableValue { + var value: Timestamp = new Timestamp(0) + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Timestamp] + } + def copy() = { + val newCopy = new MutableTimestamp + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} final class MutableAny extends MutableValue { var value: Any = _ @@ -307,6 +337,25 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableByte].value } + override def setDate(ordinal: Int, value: Date): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableDate] + currentValue.isNull = false + currentValue.value = value + } + + override def getDate(i: Int): Date = { + values(i).asInstanceOf[MutableDate].value + } + override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] + currentValue.isNull = false + currentValue.value = value + } + + override def getTimestamp(i: Int): Timestamp = { + values(i).asInstanceOf[MutableTimestamp].value + } + override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 93c19325151bf..548a9185998c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -34,8 +34,8 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), 
DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) case d: Date => Literal(d, DateType) + case t: Timestamp => Literal(t, TimestampType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 5dd19dd12d8dd..67f9bc0378cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -100,9 +100,9 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType | "DecimalType()" ^^^ DecimalType.Unlimited | fixedDecimalType + | "DateType" ^^^ DateType | "TimestampType" ^^^ TimestampType ) @@ -195,7 +195,8 @@ case object NullType extends DataType object NativeType { val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, + ByteType, StringType, DateType, TimestampType) def unapply(dt: DataType): Boolean = all.contains(dt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala new file mode 100644 index 0000000000000..189412b312ccc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date => JDate, Timestamp => JTimestamp} +import scala.language.implicitConversions + +/* + * Subclass of java.sql.Date which provides the usual comparison + * operators (as required for catalyst expressions) and which can + * be constructed from a string. + * + * scala> val d1 = Date("2014-02-01") + * d1: Date = 2014-02-01 + * + * scala> val d2 = Date("2014-02-02") + * d2: Date = 2014-02-02 + * + * scala> d1 < d2 + * res1: Boolean = true + */ + +class Date(milliseconds: Long) extends JDate(milliseconds) { + def <(that: Date): Boolean = this.before(that) + def >(that: Date): Boolean = this.after(that) + def <=(that: Date): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Date): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Date): Boolean = this.equals(that) +} + +object Date { + def apply(init: String) = new Date(JDate.valueOf(init).getTime) +} + +/* + * Analogous subclass of java.sql.Timestamp. 
+ * + * scala> val ts1 = Timestamp("2014-03-04 12:34:56.12") + * ts1: Timestamp = 2014-03-04 12:34:56.12 + * + * scala> val ts2 = Timestamp("2014-03-04 12:34:56.13") + * ts2: Timestamp = 2014-03-04 12:34:56.13 + * + * scala> ts1 < ts2 + * res13: Boolean = true + */ + +class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { + def <(that: Timestamp): Boolean = this.before(that) + def >(that: Timestamp): Boolean = this.after(that) + def <=(that: Timestamp): Boolean = (this.before(that) || this.equals(that)) + def >=(that: Timestamp): Boolean = (this.after(that) || this.equals(that)) + def ===(that: Timestamp): Boolean = this.equals(that) +} + +object Timestamp { + def apply(init: String) = new Timestamp(JTimestamp.valueOf(init).getTime) +} + +/* + * Implicit conversions. + */ + +object TimeConversions { + + implicit def JDateToDate(jdate: JDate): Date = { + new Date(jdate.getTime) + } + + implicit def JTimestampToTimestamp(jtimestamp: JTimestamp): Timestamp = { + new Timestamp(jtimestamp.getTime) + } + + implicit def DateToJDate(date: Date): JDate = { + new JDate(date.getTime) + } + + implicit def TimestampToJTimestamp(timestamp: Timestamp): JTimestamp = { + new JTimestamp(timestamp.getTime) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index ab66c85c4f242..475b65c2798c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -372,6 +372,11 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } + + def append(v: Date, buffer: ByteBuffer) { + buffer.putLong(v.getTime) + } + } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 51dad54f1a3f3..e9a69a0388606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -162,22 +162,22 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `java.sql.Timestamp` values. + * The data type representing `java.sql.Date` values. * * @group dataType */ @DeveloperApi - val TimestampType = catalyst.types.TimestampType + val DateType = catalyst.types.DateType /** * :: DeveloperApi :: * - * The data type representing `java.sql.Date` values. + * The data type representing `java.sql.Timestamp` values. * * @group dataType */ @DeveloperApi - val DateType = catalyst.types.DateType + val TimestampType = catalyst.types.TimestampType /** * :: DeveloperApi :: @@ -460,4 +460,34 @@ package object sql { */ @DeveloperApi type MetadataBuilder = catalyst.util.MetadataBuilder + + /** + * :: DeveloperApi :: + * + * A Timestamp class which support the standard comparison + * operators, for use in DSL expressions. Implicit conversions to + * java.sql.Date are provided. The class intializer accepts a + * String, e.g. + * + * val ts = Date("2014-01-01") + * + * @group dataType + */ + @DeveloperApi + val Date = catalyst.expressions.Date + + /** + * :: DeveloperApi :: + * + * A Timestamp class which support the standard comparison + * operators, for use in DSL expressions. Implicit conversions to + * java.sql.timestamp are provided. 
The class intializer accepts a + * String, e.g. + * + * val ts = Timestamp("2014-01-01 12:34:56.78") + * + * @group timeClasses + */ + @DeveloperApi + val Timestamp = catalyst.expressions.Timestamp } From dc2ed7237140f9de62c1979a0a714543a42f3bd3 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Sun, 2 Nov 2014 21:11:35 -0600 Subject: [PATCH 66/79] Correcting a typo in the documentation. --- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e9a69a0388606..35a17a45fd65f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -464,10 +464,9 @@ package object sql { /** * :: DeveloperApi :: * - * A Timestamp class which support the standard comparison - * operators, for use in DSL expressions. Implicit conversions to - * java.sql.Date are provided. The class intializer accepts a - * String, e.g. + * A Date class which support the standard comparison operators, for + * use in DSL expressions. Implicit conversions to java.sql.Date + * are provided. The class intializer accepts a String, e.g. * * val ts = Date("2014-01-01") * From b0036199d4910073e3318009fcfe04dc81fd7fba Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Mon, 3 Nov 2014 08:45:43 -0600 Subject: [PATCH 67/79] Correcting the bugs and issues pointed out in liancheng's very helpful comments. --- .../spark/sql/catalyst/dsl/package.scala | 47 ++++++++++++------- .../sql/catalyst/expressions/Projection.scala | 34 -------------- .../spark/sql/catalyst/expressions/Row.scala | 27 ----------- .../expressions/SpecificMutableRow.scala | 7 --- .../spark/sql/catalyst/types/timetypes.scala | 26 +++++----- .../scala/org/apache/spark/sql/package.scala | 10 ++-- 6 files changed, 50 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index ed75a183b7fa8..6b30115177358 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -153,26 +153,39 @@ package object dsl { * where a literal is being combined with a symbol. Without these an * expression such as 0 < 'x is not recognized. 
*/ - implicit class InitialLiteral(x: Any) { + case class LhsLiteral(x: Any) { val literal = Literal(x) - def + (other: Symbol):Expression = {literal + other} - def - (other: Symbol):Expression = {literal - other} - def * (other: Symbol):Expression = {literal * other} - def / (other: Symbol):Expression = {literal / other} - def % (other: Symbol):Expression = {literal % other} - - def && (other: Symbol):Expression = {literal && other} - def || (other: Symbol):Expression = {literal || other} - - def < (other: Symbol):Expression = {literal < other} - def <= (other: Symbol):Expression = {literal <= other} - def > (other: Symbol):Expression = {literal > other} - def >= (other: Symbol):Expression = {literal >= other} - def === (other: Symbol):Expression = {literal === other} - def <=> (other: Symbol):Expression = {literal <=> other} - def !== (other: Symbol):Expression = {literal !== other} + def + (other: Symbol): Expression = literal + other + def - (other: Symbol): Expression = literal - other + def * (other: Symbol): Expression = literal * other + def / (other: Symbol): Expression = literal / other + def % (other: Symbol): Expression = literal % other + + def && (other: Symbol): Expression = literal && other + def || (other: Symbol): Expression = literal || other + + def < (other: Symbol): Expression = literal < other + def <= (other: Symbol): Expression = literal <= other + def > (other: Symbol): Expression = literal > other + def >= (other: Symbol): Expression = literal >= other + def === (other: Symbol): Expression = literal === other + def <=> (other: Symbol): Expression = literal <=> other + def !== (other: Symbol): Expression = literal !== other } + implicit def booleanToLhsLiteral(b: Boolean) = new LhsLiteral(b) + implicit def byteToLhsLiteral(b: Byte) = new LhsLiteral(b) + implicit def shortToLhsLiteral(s: Short) = new LhsLiteral(s) + implicit def intToLhsLiteral(i: Int) = new LhsLiteral(i) + implicit def longToLhsLiteral(l: Long) = new LhsLiteral(l) + implicit def floatToLhsLiteral(f: Float) = new LhsLiteral(f) + implicit def doubleToLhsLiteral(d: Double) = new LhsLiteral(d) + implicit def stringToLhsLiteral(s: String) = new LhsLiteral(s) + implicit def bigDecimalToLhsLiteral(d: BigDecimal) = new LhsLiteral(d) + implicit def decimalToLhsLiteral(d: Decimal) = new LhsLiteral(d) + implicit def dateToLhsLiteral(d: Date) = new LhsLiteral(d) + implicit def timestampToLhsLiteral(t: Timestamp) = new LhsLiteral(t) + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? 
implicit class DslString(val s: String) extends ImplicitOperators { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4391bcede66d1..45b5e6e2c289a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -140,12 +140,6 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -238,13 +232,6 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -331,13 +318,6 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -424,13 +404,6 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) @@ -517,13 +490,6 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) - - def getDate(i: Int): Date = - if (i < row1.size) row1.getDate(i) else row2.getDate(i - row1.size) - - def getTimestamp(i: Int): Timestamp = - if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size) - override def getAs[T](i: Int): T = if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 738a5f3a53f3c..5c0864c896628 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,31 +44,6 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) - - /** - * This method can be used to construct a [[Row]] from a [[Seq]] of Strings, - * converting each item to the type specified in a [[StructType]] schema. - * Only primitive types can be used. - */ - def fromStringsBySchema(strings: Seq[String], schema: StructType): Row = { - val values = for { - (field, str) <- schema.fields zip strings - item = field.dataType match { - case IntegerType => str.toInt - case LongType => str.toLong - case DoubleType => str.toDouble - case FloatType => str.toFloat - case ByteType => str.toByte - case ShortType => str.toShort - case StringType => str - case BooleanType => (str != "") - case DateType => Date.valueOf(str) - case TimestampType => Timestamp.valueOf(str) - case DecimalType() => new BigDecimal(str) - } - } yield item - new GenericRow(values.toArray) - } } /** @@ -91,8 +66,6 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String - def getDate(i: Int): Date - def getTimestamp(i: Int): Timestamp def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 34a10cf4a6945..0c4eb53ee81f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -343,19 +343,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR currentValue.value = value } - override def getDate(i: Int): Date = { - values(i).asInstanceOf[MutableDate].value - } override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] currentValue.isNull = false currentValue.value = value } - override def getTimestamp(i: Int): Timestamp = { - values(i).asInstanceOf[MutableTimestamp].value - } - override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala index 189412b312ccc..fcb77e640a8ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/timetypes.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date => JDate, Timestamp => JTimestamp} +import java.sql.{Date, Timestamp} import scala.language.implicitConversions /* @@ -35,7 +35,7 @@ import scala.language.implicitConversions * res1: Boolean = true */ -class Date(milliseconds: Long) extends JDate(milliseconds) { +class RichDate(milliseconds: Long) extends Date(milliseconds) { def <(that: Date): Boolean = this.before(that) def >(that: Date): Boolean = this.after(that) def <=(that: Date): Boolean = (this.before(that) || this.equals(that)) @@ -43,8 +43,8 @@ class Date(milliseconds: Long) extends JDate(milliseconds) { def ===(that: Date): Boolean = this.equals(that) } -object Date { - def apply(init: String) = new Date(JDate.valueOf(init).getTime) +object RichDate { + def apply(init: String) = new RichDate(Date.valueOf(init).getTime) } /* @@ -60,7 +60,7 @@ object Date { * res13: Boolean = true */ 
-class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { +class RichTimestamp(milliseconds: Long) extends Timestamp(milliseconds) { def <(that: Timestamp): Boolean = this.before(that) def >(that: Timestamp): Boolean = this.after(that) def <=(that: Timestamp): Boolean = (this.before(that) || this.equals(that)) @@ -68,8 +68,8 @@ class Timestamp(milliseconds: Long) extends JTimestamp(milliseconds) { def ===(that: Timestamp): Boolean = this.equals(that) } -object Timestamp { - def apply(init: String) = new Timestamp(JTimestamp.valueOf(init).getTime) +object RichTimestamp { + def apply(init: String) = new RichTimestamp(Timestamp.valueOf(init).getTime) } /* @@ -78,20 +78,20 @@ object Timestamp { object TimeConversions { - implicit def JDateToDate(jdate: JDate): Date = { + implicit def javaDateToRichDate(jdate: Date): RichDate = { new Date(jdate.getTime) } - implicit def JTimestampToTimestamp(jtimestamp: JTimestamp): Timestamp = { + implicit def javaTimestampToRichTimestamp(jtimestamp: Timestamp): RichTimestamp = { new Timestamp(jtimestamp.getTime) } - implicit def DateToJDate(date: Date): JDate = { - new JDate(date.getTime) + implicit def richDateToJavaDate(date: RichDate): Date = { + new Date(date.getTime) } - implicit def TimestampToJTimestamp(timestamp: Timestamp): JTimestamp = { - new JTimestamp(timestamp.getTime) + implicit def richTimestampToJavaTimestamp(timestamp: RichTimestamp): Timestamp = { + new Timestamp(timestamp.getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 35a17a45fd65f..9faa24a085976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -468,12 +468,14 @@ package object sql { * use in DSL expressions. Implicit conversions to java.sql.Date * are provided. The class intializer accepts a String, e.g. * - * val ts = Date("2014-01-01") + * {{{ + * val d = Date("2014-01-01") + * }}} * * @group dataType */ @DeveloperApi - val Date = catalyst.expressions.Date + val Date = catalyst.expressions.RichDate /** * :: DeveloperApi :: @@ -483,10 +485,12 @@ package object sql { * java.sql.timestamp are provided. The class intializer accepts a * String, e.g. * + * {{{ * val ts = Timestamp("2014-01-01 12:34:56.78") + * }}} * * @group timeClasses */ @DeveloperApi - val Timestamp = catalyst.expressions.Timestamp + val Timestamp = catalyst.expressions.RichTimestamp } From a006ddb1a6e0ce4aa80a3b1f4d51570d572f2741 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Mon, 3 Nov 2014 14:32:10 -0600 Subject: [PATCH 68/79] Make implicit conversions for Literal op Symbol return a specific type, e.g. Add(1, 'x). 
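For concreteness, a minimal sketch of the expressions these conversions are meant to
produce once the catalyst DSL implicits are in scope. The REPL-style snippet below is
illustrative only and not part of the patch; the attribute name 'x is made up, and the
import paths simply follow the packages shown in the diffs above.

    // Assumes the LhsLiteral conversions in this patch plus the existing
    // symbol-to-attribute implicit from the catalyst DSL.
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{Add, LessThan}

    val sum: Add      = 1 + 'x   // intToLhsLiteral applies, yielding Add(Literal(1), 'x)
    val cmp: LessThan = 0 < 'x   // yields LessThan(Literal(0), 'x); without the
                                 // conversion this line would not compile
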
--- .../spark/sql/catalyst/dsl/package.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 6b30115177358..a32f1d6080df4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -155,22 +155,22 @@ package object dsl { */ case class LhsLiteral(x: Any) { val literal = Literal(x) - def + (other: Symbol): Expression = literal + other - def - (other: Symbol): Expression = literal - other - def * (other: Symbol): Expression = literal * other - def / (other: Symbol): Expression = literal / other - def % (other: Symbol): Expression = literal % other - - def && (other: Symbol): Expression = literal && other - def || (other: Symbol): Expression = literal || other - - def < (other: Symbol): Expression = literal < other - def <= (other: Symbol): Expression = literal <= other - def > (other: Symbol): Expression = literal > other - def >= (other: Symbol): Expression = literal >= other - def === (other: Symbol): Expression = literal === other - def <=> (other: Symbol): Expression = literal <=> other - def !== (other: Symbol): Expression = literal !== other + def + (other: Symbol) = Add(literal, other) + def - (other: Symbol) = Subtract(literal, other) + def * (other: Symbol) = Multiply(literal, other) + def / (other: Symbol) = Divide(literal, other) + def % (other: Symbol) = Remainder(literal, other) + + def && (other: Symbol) = And(literal, other) + def || (other: Symbol) = Or(literal, other) + + def < (other: Symbol) = LessThan(literal, other) + def <= (other: Symbol) = LessThanOrEqual(literal, other) + def > (other: Symbol) = GreaterThan(literal, other) + def >= (other: Symbol) = GreaterThanOrEqual(literal, other) + def === (other: Symbol) = EqualTo(literal, other) + def <=> (other: Symbol) = EqualNullSafe(literal, other) + def !== (other: Symbol) = Not(EqualTo(literal, other)) } implicit def booleanToLhsLiteral(b: Boolean) = new LhsLiteral(b) From 19352898545a12804e1b0ce0cc2ceec2d3509df2 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Tue, 4 Nov 2014 00:14:13 -0600 Subject: [PATCH 69/79] Reversed random line permutations. Eliminated all getters and setters for Date and Timestamp. Added Date and Timestamp to NativeType.defaultSizeOf. 
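As a rough illustration of the access pattern that remains after dropping the dedicated
Date and Timestamp accessors, time values are read back through the generic getAs method
already defined on Row. The row contents and variable names below are invented for the
example and are not part of the patch.

    import java.sql.{Date, Timestamp}
    import org.apache.spark.sql.catalyst.expressions.GenericRow

    val row = new GenericRow(Array[Any](
      Date.valueOf("2014-02-01"), Timestamp.valueOf("2014-03-04 12:34:56.12")))
    val d  = row.getAs[Date](0)        // generic accessor instead of the removed getDate
    val ts = row.getAs[Timestamp](1)   // generic accessor instead of the removed getTimestamp
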
--- .../spark/sql/catalyst/dsl/package.scala | 8 ++++---- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../sql/catalyst/expressions/Projection.scala | 2 -- .../spark/sql/catalyst/expressions/Row.scala | 19 +------------------ .../expressions/SpecificMutableRow.scala | 12 ------------ .../sql/catalyst/expressions/literals.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 4 +++- .../spark/sql/columnar/ColumnType.scala | 5 ----- .../scala/org/apache/spark/sql/package.scala | 8 ++++---- 9 files changed, 15 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a32f1d6080df4..4b05ac62b9dc6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -127,9 +127,9 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) implicit def decimalToLiteral(d: Decimal) = Literal(d) - implicit def dateToLiteral(d: Date) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -222,6 +222,9 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() @@ -229,9 +232,6 @@ package object dsl { def decimal(precision: Int, scale: Int) = AttributeReference(s, DecimalType(precision, scale), nullable = true)() - /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() - /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 38172eb1a50fb..22009666196a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -31,8 +31,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true - case (StringType, DateType) => true case (StringType, TimestampType) => true + case (StringType, DateType) => true case (_: NumericType, DateType) => true case (BooleanType, DateType) => true case (DateType, _: NumericType) => true @@ -333,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case decimal: DecimalType => castToDecimal(decimal) case DateType => castToDate + case decimal: DecimalType => castToDecimal(decimal) case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => 
castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 45b5e6e2c289a..80a54eb74a352 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date, Timestamp} - /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 5c0864c896628..99b9e6efbab90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.NativeType import java.sql.{Date, Timestamp} import java.math.BigDecimal @@ -101,8 +101,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - def setDate(ordinal: Int, value: Date) - def setTimestamp(ordinal: Int, value: Timestamp) } /** @@ -123,9 +121,6 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException - def getDate(i: Int): Date = throw new UnsupportedOperationException - def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException - override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this @@ -190,16 +185,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } - def getDate(i: Int): Date = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") - values(i).asInstanceOf[Date] - } - - def getTimestamp(i: Int): Timestamp = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") - values(i).asInstanceOf[Timestamp] - } - // Custom hashCode function that matches the efficient code generated version. 
override def hashCode(): Int = { var result: Int = 37 @@ -243,8 +228,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - override def setDate(ordinal: Int,value: Date): Unit = { values(ordinal) = value } - override def setTimestamp(ordinal: Int,value: Timestamp): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 0c4eb53ee81f2..9f977bf6c2a0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -337,18 +337,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableByte].value } - override def setDate(ordinal: Int, value: Date): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableDate] - currentValue.isNull = false - currentValue.value = value - } - - override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableTimestamp] - currentValue.isNull = false - currentValue.value = value - } - override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 548a9185998c3..93c19325151bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -34,8 +34,8 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case d: Date => Literal(d, DateType) case t: Timestamp => Literal(t, TimestampType) + case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 67f9bc0378cfe..1d4c4783154bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -100,9 +100,9 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType | "DecimalType()" ^^^ DecimalType.Unlimited | fixedDecimalType - | "DateType" ^^^ DateType | "TimestampType" ^^^ TimestampType ) @@ -208,6 +208,8 @@ object NativeType { FloatType -> 4, ShortType -> 2, ByteType -> 1, + DateType -> 8, + TimestampType -> 12, StringType -> 4096) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 475b65c2798c3..ab66c85c4f242 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -372,11 +372,6 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } - - def append(v: Date, buffer: ByteBuffer) { - buffer.putLong(v.getTime) - } - } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 9faa24a085976..30b0c77d4c461 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -162,22 +162,22 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `java.sql.Date` values. + * The data type representing `java.sql.Timestamp` values. * * @group dataType */ @DeveloperApi - val DateType = catalyst.types.DateType + val TimestampType = catalyst.types.TimestampType /** * :: DeveloperApi :: * - * The data type representing `java.sql.Timestamp` values. + * The data type representing `java.sql.Date` values. * * @group dataType */ @DeveloperApi - val TimestampType = catalyst.types.TimestampType + val DateType = catalyst.types.DateType /** * :: DeveloperApi :: From d59e5d930a6b0fafb5d8c06403a65df97adb0e74 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Tue, 4 Nov 2014 00:25:20 -0600 Subject: [PATCH 70/79] A couple more pointless changes undone. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/Projection.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 68419834069b1..9cda373623cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -159,9 +159,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType + case obj: DateType.JvmType => DateType case obj: BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited - case obj: DateType.JvmType => DateType case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 80a54eb74a352..e7e81a21fdf03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions + /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. 
* @param expressions a sequence of expressions that determine the value of each column of the From 43406fe03c9f9cdfc9a07ada1f948d12a1f201f8 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 11:22:29 -0600 Subject: [PATCH 71/79] In tests where the scalatest assert conversion collides with the new DSL conversion (due to the existence of an === operator), applied the transformation assert(X == Y) --> assert(convertToEqualizer(X).===(Y)) --- .../ExpressionEvaluationSuite.scala | 89 ++--- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +- .../apache/spark/sql/CachedTableSuite.scala | 11 +- .../org/apache/spark/sql/DslQuerySuite.scala | 12 +- .../org/apache/spark/sql/JoinSuite.scala | 13 +- .../org/apache/spark/sql/SQLConfSuite.scala | 9 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 11 +- .../columnar/PartitionBatchPruningSuite.scala | 11 +- .../spark/sql/execution/PlannerSuite.scala | 17 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 311 +++++++++--------- .../spark/sql/hive/StatisticsSuite.scala | 25 +- .../sql/hive/execution/HiveQuerySuite.scala | 21 +- .../spark/sql/parquet/HiveParquetSuite.scala | 14 +- 13 files changed, 331 insertions(+), 238 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 6bfa0dbd65ba7..128a4860843ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -32,6 +32,13 @@ import org.apache.spark.sql.catalyst.types._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
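+ * (Illustration: the DSL conversions imported above also provide an ===
+ *  operator, so a bare comparison such as
+ *    assert(literals.size === 4)
+ *  may resolve to the DSL operator, or not resolve at all, instead of
+ *  ScalaTest's Equalizer. Rewriting it as
+ *    assert(EQ(literals.size).===(4))
+ *  picks the ScalaTest conversion explicitly while keeping the same check.)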
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class ExpressionEvaluationSuite extends FunSuite { test("literals") { @@ -318,18 +325,18 @@ class ExpressionEvaluationSuite extends FunSuite { intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - assert(("abcdef" cast StringType).nullable === false) - assert(("abcdef" cast BinaryType).nullable === false) - assert(("abcdef" cast BooleanType).nullable === false) - assert(("abcdef" cast TimestampType).nullable === true) - assert(("abcdef" cast LongType).nullable === true) - assert(("abcdef" cast IntegerType).nullable === true) - assert(("abcdef" cast ShortType).nullable === true) - assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType.Unlimited).nullable === true) - assert(("abcdef" cast DecimalType(4, 2)).nullable === true) - assert(("abcdef" cast DoubleType).nullable === true) - assert(("abcdef" cast FloatType).nullable === true) + assert(EQ(("abcdef" cast StringType).nullable).===(false)) + assert(EQ(("abcdef" cast BinaryType).nullable).===(false)) + assert(EQ(("abcdef" cast BooleanType).nullable).===(false)) + assert(EQ(("abcdef" cast TimestampType).nullable).===(true)) + assert(EQ(("abcdef" cast LongType).nullable).===(true)) + assert(EQ(("abcdef" cast IntegerType).nullable).===(true)) + assert(EQ(("abcdef" cast ShortType).nullable).===(true)) + assert(EQ(("abcdef" cast ByteType).nullable).===(true)) + assert(EQ(("abcdef" cast DecimalType.Unlimited).nullable).===(true)) + assert(EQ(("abcdef" cast DecimalType(4, 2)).nullable).===(true)) + assert(EQ(("abcdef" cast DoubleType).nullable).===(true)) + assert(EQ(("abcdef" cast FloatType).nullable).===(true)) checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) } @@ -346,15 +353,15 @@ class ExpressionEvaluationSuite extends FunSuite { // - Values that would overflow the target precision should turn into null // - Because of this, casts to fixed-precision decimals should be nullable - assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + assert(EQ(Cast(Literal(123), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03f), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(10.03), DecimalType.Unlimited).nullable).===(false)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable).===(false)) - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + assert(EQ(Cast(Literal(123), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03f), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(10.03), DecimalType(2, 1)).nullable).===(true)) + assert(EQ(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable).===(true)) checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) @@ -500,26 +507,26 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, 
c6)), "c", row) checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5)).nullable).===(true)) val c4_notNull = 'a.boolean.notNull.at(3) val c5_notNull = 'a.boolean.notNull.at(4) val c6_notNull = 'a.boolean.notNull.at(5) - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable).===(false)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable).===(true)) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable).===(true)) + assert(EQ(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable).===(true)) } test("complex type") { @@ -559,11 +566,11 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(EQ(GetField(BoundReference(2,typeS, nullable = true), "a").nullable).===(true)) + assert(EQ(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable).===(false)) - assert(GetField(Literal(null, typeS), "a").nullable === true) - assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(EQ(GetField(Literal(null, typeS), "a").nullable).===(true)) + assert(EQ(GetField(Literal(null, typeS_notNullable), "a").nullable).===(true)) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -717,10 +724,10 @@ class ExpressionEvaluationSuite extends FunSuite { val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(EQ(Substring(s, 
Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable).===(false)) + assert(EQ(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable).===(true)) + assert(EQ(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable).===(true)) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 036fd3fa1d6a1..82887dc9d4604 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq def nullable = true @@ -36,21 +43,21 @@ case class Dummy(optKey: Option[Expression]) extends Expression { class TreeNodeSuite extends FunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } - assert(after === Literal(2)) + assert(EQ(after).===(Literal(2))) } test("one child changed") { val before = Add(Literal(1), Literal(2)) val after = before transform { case Literal(2, _) => Literal(1) } - assert(after === Add(Literal(1), Literal(1))) + assert(EQ(after).===(Add(Literal(1), Literal(1)))) } test("no change") { val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val after = before transform { case Literal(5, _) => Literal(1)} - assert(before === after) + assert(EQ(before).===(after)) // Ensure that the objects after are the same objects before the transformation. 
before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach { case (b, a) => assert(b eq a) @@ -61,7 +68,7 @@ class TreeNodeSuite extends FunSuite { val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) val literals = tree collect {case l: Literal => l} - assert(literals.size === 4) + assert(EQ(literals.size).===(4)) (1 to 4).foreach(i => assert(literals contains Literal(i))) } @@ -74,7 +81,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("post-order transform") { @@ -86,7 +93,7 @@ class TreeNodeSuite extends FunSuite { case l: Literal => actual.append(l.toString); l } - assert(expected === actual) + assert(EQ(expected).===(actual)) } test("transform works on nodes with Option children") { @@ -95,13 +102,13 @@ class TreeNodeSuite extends FunSuite { val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy1 transformUp toZero - assert(actual === Dummy(Some(Literal(0)))) + assert(EQ(actual).===(Dummy(Some(Literal(0))))) actual = dummy2 transform toZero - assert(actual === Dummy(None)) + assert(EQ(actual).===(Dummy(None))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 765fa82776341..d0fec34fd2cf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,13 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class BigData(s: String) class CachedTableSuite extends QueryTest { @@ -74,7 +81,7 @@ class CachedTableSuite extends QueryTest { val data = "*" * 10000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) + assert(EQ(table("bigData").count()).===(200000L)) table("bigData").unpersist(blocking = true) } @@ -228,7 +235,7 @@ class CachedTableSuite extends QueryTest { table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum - assert(cached.statistics.sizeInBytes === actualSizeInBytes) + assert(EQ(cached.statistics.sizeInBytes).===(actualSizeInBytes)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index e70ad891eea36..687c5a2707587 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -24,6 +24,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + + class DslQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -171,7 +179,7 @@ class DslQuerySuite extends QueryTest { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(EQ(testData2.count()).===(testData2.map(_ => 1).count())) } test("null count") { @@ -192,7 +200,7 @@ class DslQuerySuite extends QueryTest { } test("zero count") { - assert(emptyTableData.count() === 0) + assert(EQ(emptyTableData.count()).===(0)) } test("except") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8b4cf5bac0187..0260eea467d1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,6 +25,13 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOu import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -34,7 +41,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.as('y) val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } def assertJoin(sqlString: String, c: Class[_]): Any = { @@ -50,7 +57,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") } @@ -104,7 +111,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed val planned = planner.HashJoin(join) - assert(planned.size === 1) + assert(EQ(planned.size).===(1)) } test("inner join where, one match per row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 60701f0e154f8..4a09ed517d6e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -24,6 +24,13 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" @@ -38,7 +45,7 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { test("programmatic ways of basic setting and getting") { clear() - assert(getAllConfs.size === 0) + assert(EQ(getAllConfs.size).===(0)) setConf(testKey, testVal) assert(getConf(testKey) == testVal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ef9b76b1e251e..95582eec06975 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,18 +22,25 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("Simple UDF") { registerFunction("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(EQ(sql("SELECT strLenScala('test')").first().getInt(0)).===(4)) } test("TwoArgument UDF") { registerFunction("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(EQ(sql("SELECT strLenScala('test', 1)").first().getInt(0)).===(5)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 9ba3c210171bd..45086f88a27ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,6 +22,13 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning @@ -107,8 +114,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(EQ(readBatches).===(expectedReadBatches), "Wrong number of read batches") + assert(EQ(readPartitions).===(expectedReadPartitions), "Wrong number of read partitions") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a5af71acfc79a..0ed380c63527a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -34,8 +41,8 @@ class PlannerSuite extends FunSuite { val logicalUnions = query collect { case u: logical.Union => u } val physicalUnions = planned collect { case u: execution.Union => u } - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) + assert(EQ(logicalUnions.size).===(2)) + assert(EQ(physicalUnions.size).===(1)) } test("count is partially aggregated") { @@ -43,7 +50,7 @@ class PlannerSuite extends FunSuite { val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - assert(aggregations.size === 2) + assert(EQ(aggregations.size).===(2)) } test("count distinct is partially aggregated") { @@ -71,7 +78,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) @@ -91,7 +98,7 @@ class PlannerSuite extends FunSuite { val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(EQ(broadcastHashJoins.size).===(1), "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") 
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1b11..25031910c30de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestRDDEntry(key: Int, value: String) case class NullReflectData( @@ -172,7 +179,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -188,7 +195,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -204,7 +211,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) + assert(EQ(actualCodec).===("UNCOMPRESSED" :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -220,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -236,7 +243,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase 
:: Nil) + assert(EQ(actualCodec).===(TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil)) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -285,8 +292,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") + assert(EQ(result.size).===(9), "self-join result has incorrect size") + assert(EQ(result(0).size).===(12), "result row has incorrect size") result.zipWithIndex.foreach { case (row, index) => row.zipWithIndex.foreach { case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") @@ -296,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Import of simple Parquet file") { val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) + assert(EQ(result.size).===(15)) result.zipWithIndex.foreach { case (row, index) => { val checkBoolean = @@ -304,12 +311,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA row(0) == true else row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assert(EQ(checkBoolean).===(true), s"boolean field value in line $index did not match") + if (index % 5 == 0) assert(EQ(row(1)).===(5), s"int field value in line $index did not match") + assert(EQ(row(2)).===("abc"), s"string field value in line $index did not match") + assert(EQ(row(3)).===((index.toLong << 33)), s"long value in line $index did not match") + assert(EQ(row(4)).===(2.5F), s"float field value in line $index did not match") + assert(EQ(row(5)).===(4.5D), s"double field value in line $index did not match") } } } @@ -319,11 +326,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") + assert(EQ(row(0)).===(true), s"boolean field value in line $index did not match (every third row)") else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") + assert(EQ(row(0)).===(false), s"boolean field value in line $index did not match") + assert(EQ(row(1)).===((index.toLong << 33)), s"long field value in line $index did not match") + assert(EQ(row.size).===(2), s"number of columns in projection in line $index is incorrect") } } } @@ -381,8 +388,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + 
assert(EQ(rdd_copy(i).apply(0)).===(rdd_orig(i).key), s"key error in line $i") + assert(EQ(rdd_copy(i).apply(1)).===(rdd_orig(i).value), s"value error in line $i") } Utils.deleteRecursively(file) } @@ -396,11 +403,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() - assert(rdd_copy1.size === 100) + assert(EQ(rdd_copy1.size).===(100)) sql("INSERT INTO dest SELECT * FROM source") val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) - assert(rdd_copy2.size === 200) + assert(EQ(rdd_copy2.size).===(200)) Utils.deleteRecursively(dirname) } @@ -408,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) - assert(double_rdd.size === 30) + assert(EQ(double_rdd.size).===(30)) // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) @@ -425,7 +432,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -440,7 +447,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(EQ(rdd_saved(0)).===(Seq.fill(5)(null))) Utils.deleteRecursively(file) assert(true) } @@ -478,11 +485,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val attribute2 = new AttributeReference("second", IntegerType, false)() val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) - assert(badfilter.isDefined === false) + assert(EQ(badfilter.isDefined).===(false)) val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) val badfilter2 = ParquetFilters.createFilter(predicate6) - assert(badfilter2.isDefined === false) + assert(EQ(badfilter2.isDefined).===(false)) } test("test filter by predicate pushdown") { @@ -492,21 +499,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) + assert(EQ(result1.size).===(50)) + assert(EQ(result1(0)(1)).===(100)) + assert(EQ(result1(49)(1)).===(149)) val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result2 = query2.collect() - assert(result2.size === 50) + assert(EQ(result2.size).===(50)) if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) + assert(EQ(result2(0)(1)).===(151)) + assert(EQ(result2(49)(1)).===(200)) } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) + assert(EQ(result2(0)(1)).===(150)) + assert(EQ(result2(49)(1)).===(199)) } 
} for(myval <- Seq("myint", "mylong")) { @@ -515,11 +522,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) + assert(EQ(result3.size).===(20)) + assert(EQ(result3(0)(1)).===(0)) + assert(EQ(result3(9)(1)).===(9)) + assert(EQ(result3(10)(1)).===(191)) + assert(EQ(result3(19)(1)).===(200)) } for(myval <- Seq("mydouble", "myfloat")) { val result4 = @@ -534,18 +541,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // currently no way to specify float constants in SqlParser? sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) + assert(EQ(result4.size).===(20)) + assert(EQ(result4(0)(1)).===(0)) + assert(EQ(result4(9)(1)).===(9)) + assert(EQ(result4(10)(1)).===(191)) + assert(EQ(result4(19)(1)).===(200)) } val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") assert( query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val booleanResult = query5.collect() - assert(booleanResult.size === 10) + assert(EQ(booleanResult.size).===(10)) for(i <- 0 until 10) { if (!booleanResult(i).getBoolean(0)) { fail(s"Boolean value in result row $i not true") @@ -559,16 +566,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val stringResult = query6.collect() - assert(stringResult.size === 1) + assert(EQ(stringResult.size).===(1)) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) + assert(EQ(stringResult(0).getInt(1)).===(100)) val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") assert( query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val optResult = query7.collect() - assert(optResult.size === 20) + assert(EQ(optResult.size).===(20)) for(i <- 0 until 20) { if (optResult(i)(7) != i * 2) { fail(s"optional Int value in result row $i should be ${2*4*i}") @@ -580,21 +587,21 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) + assert(EQ(result8.size).===(25)) + assert(EQ(result8(0)(7)).===(100)) + assert(EQ(result8(24)(7)).===(148)) val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") assert( query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result9 = query9.collect() - assert(result9.size === 25) + assert(EQ(result9.size).===(25)) if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) 
=== 152) - assert(result9(24)(7) === 200) + assert(EQ(result9(0)(7)).===(152)) + assert(EQ(result9(24)(7)).===(200)) } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) + assert(EQ(result9(0)(7)).===(150)) + assert(EQ(result9(24)(7)).===(198)) } } val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") @@ -602,15 +609,15 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result10 = query10.collect() - assert(result10.size === 1) + assert(EQ(result10.size).===(1)) assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) + assert(EQ(result10(0).getInt(7)).===(100)) val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") assert( query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], "Top operator should be ParquetTableScan after pushdown") val result11 = query11.collect() - assert(result11.size === 7) + assert(EQ(result11.size).===(7)) for(i <- 0 until 6) { if (!result11(i).getBoolean(6)) { fail(s"optional Boolean value in result row $i not true") @@ -623,7 +630,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) + assert(EQ(query.collect().size).===(10)) } test("Importing nested Parquet file (Addressbook)") { @@ -632,32 +639,32 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .collect() assert(result != null) - assert(result.size === 2) + assert(EQ(result.size).===(2)) val first_record = result(0) val second_record = result(1) assert(first_record != null) assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. Nonymous") - assert(first_record(0) === "Julien Le Dem") + assert(EQ(first_record.size).===(3)) + assert(EQ(second_record(1)).===(null)) + assert(EQ(second_record(2)).===(null)) + assert(EQ(second_record(0)).===("A. 
Nonymous")) + assert(EQ(first_record(0)).===("Julien Le Dem")) val first_owner_numbers = first_record(1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] val first_contacts = first_record(2) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + assert(EQ(first_owner_numbers(0)).===("555 123 4567")) + assert(EQ(first_owner_numbers(2)).===("XXX XXX XXXX")) + assert(EQ(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size).===(2)) val first_contacts_entry_one = first_contacts(0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") + assert(EQ(first_contacts_entry_one(0)).===("Dmitriy Ryaboy")) + assert(EQ(first_contacts_entry_one(1)).===("555 987 6543")) val first_contacts_entry_two = first_contacts(1) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + assert(EQ(first_contacts_entry_two(0)).===("Chris Aniszczyk")) } test("Importing nested Parquet file (nested numbers)") { @@ -665,31 +672,31 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) + assert(EQ(result.size).===(1), "number of top-level rows incorrect") + assert(EQ(result(0).size).===(5), "number of fields in row incorrect") + assert(EQ(result(0)(0)).===(1)) + assert(EQ(result(0)(1)).===(7)) val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) + assert(EQ(subresult1.size).===(3)) + assert(EQ(subresult1(0)).===((1.toLong << 32))) + assert(EQ(subresult1(1)).===((1.toLong << 33))) + assert(EQ(subresult1(2)).===((1.toLong << 34))) val subresult2 = result(0)(3) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) + assert(EQ(subresult2.size).===(2)) + assert(EQ(subresult2(0)).===(2.5)) + assert(EQ(subresult2(1)).===(false)) val subresult3 = result(0)(4) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + assert(EQ(subresult3.size).===(2)) + assert(EQ(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(2)) val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + 
assert(EQ(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size).===(1)) + assert(EQ(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("Simple query on addressbook") { @@ -697,8 +704,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") + assert(EQ(tmp.size).===(1)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) } test("Projection in addressbook") { @@ -706,37 +713,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA data.registerTempTable("data") val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. Nonymous") - assert(tmp(1)(1) === null) + assert(EQ(tmp.size).===(2)) + assert(EQ(tmp(0).size).===(2)) + assert(EQ(tmp(0)(0)).===("Julien Le Dem")) + assert(EQ(tmp(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmp(1)(0)).===("A. Nonymous")) + assert(EQ(tmp(1)(1)).===(null)) } test("Simple query on nested int data") { val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(2.5)) val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) + assert(EQ(subresult1.size).===(2)) + assert(EQ(subresult1(0)).===(2.5)) + assert(EQ(subresult1(1)).===(false)) val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) + assert(EQ(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(7)) + assert(EQ(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(8)) + assert(EQ(result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)).===(9)) } test("nested structs") { @@ -744,17 +751,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("data") val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) + 
assert(EQ(result1.size).===(1)) + assert(EQ(result1(0).size).===(1)) + assert(EQ(result1(0)(0)).===(false)) val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0).size).===(1)) + assert(EQ(result2(0)(0)).===(true)) val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0).size).===(1)) + assert(EQ(result3(0)(0)).===(false)) } test("simple map") { @@ -763,38 +770,38 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) + .getOrElse("key1", 0)).===(1)) + assert(EQ(result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) + .getOrElse("key2", 0)).===(2)) val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) + assert(EQ(result2(0)(0)).===(1)) } test("map with struct values") { val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) + assert(EQ(result1.size).===(1)) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") + assert(EQ(result2.size).===(1)) + assert(EQ(result2(0)(0)).===(42.toLong)) + assert(EQ(result2(0)(1)).===("the answer")) } test("Writing out Addressbook and reading it back in") { @@ -808,12 +815,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpcopy") val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) + assert(EQ(tmpdata.size).===(2)) + assert(EQ(tmpdata(0).size).===(2)) + assert(EQ(tmpdata(0)(0)).===("Julien Le Dem")) + assert(EQ(tmpdata(0)(1)).===("Chris Aniszczyk")) + assert(EQ(tmpdata(1)(0)).===("A. 
Nonymous")) + assert(EQ(tmpdata(1)(1)).===(null)) Utils.deleteRecursively(tmpdir) } @@ -826,26 +833,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .toSchemaRDD .registerTempTable("tmpmapcopy") val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) + assert(EQ(result1.size).===(1)) + assert(EQ(result1(0)(0)).===(2)) val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) + assert(EQ(result2.size).===(1)) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("seven", null) assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") + assert(EQ(entry1(0)).===(42)) + assert(EQ(entry1(1)).===("the answer")) val entry2 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] .getOrElse("eight", null) assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) + assert(EQ(entry2(0)).===(49)) + assert(EQ(entry2(1)).===(null)) val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") + assert(EQ(result3.size).===(1)) + assert(EQ(result3(0)(0)).===(42.toLong)) + assert(EQ(result3(0)(1)).===("the answer")) Utils.deleteRecursively(tmpdir) } @@ -854,7 +861,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(tmpdir) createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") val result1 = sql("SELECT * FROM tmpemptytable").collect() - assert(result1.size === 0) + assert(EQ(result1.size).===(0)) Utils.deleteRecursively(tmpdir) } @@ -868,7 +875,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) - assert(a.dataType === b.dataType) + assert(EQ(a.dataType).===(b.dataType)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a90fc023e67d8..f3323952b75d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! + * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + class StatisticsSuite extends QueryTest with BeforeAndAfterAll { TestHive.reset() TestHive.cacheTables = false @@ -39,7 +46,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { case o => o } - assert(operators.size === 1) + assert(EQ(operators.size).===(1)) if (operators(0).getClass() != c) { fail( s"""$analyzeCommand expected command: $c, but got ${operators(0)} @@ -81,11 +88,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // TODO: How does it works? needs to add it back for other hive version. 
if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable")).===(defaultSizeInBytes)) } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) + assert(EQ(queryTotalSize("analyzeTable")).===(BigInt(11624))) sql("DROP TABLE analyzeTable").collect() @@ -110,11 +117,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + assert(EQ(queryTotalSize("analyzeTable_part")).===(defaultSizeInBytes)) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + assert(EQ(queryTotalSize("analyzeTable_part")).===(BigInt(17436))) sql("DROP TABLE analyzeTable_part").collect() @@ -131,7 +138,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(EQ(sizes.size).===(1), s"Size wrong for:\n ${rdd.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -151,14 +158,14 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold + assert(EQ(sizes.size).===(2) && sizes(0) <= autoBroadcastJoinThreshold && sizes(1) <= autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } - assert(bhj.size === 1, + assert(EQ(bhj.size).===(1), s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") checkAnswer(rdd, expectedAnswer) // check correctness of output @@ -172,7 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } - assert(shj.size === 1, + assert(EQ(shj.size).===(1), "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b897dff0159ff..b3ea6549fab8c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,13 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(EQ(X).===(true)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + case class TestData(a: Int, b: String) /** @@ -139,7 +146,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("CREATE TABLE AS runs once") { sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + assert(EQ(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0)).===(1), "Incorrect number of rows in created table") } @@ -161,7 +168,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(EQ(sql("SELECT 1").collect()).===(Array(Seq(1)))) setConf("spark.sql.dialect", "hiveql") } @@ -365,7 +372,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .toSet - assert(actual === expected) + assert(EQ(actual).===(expected)) } // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. @@ -415,7 +422,7 @@ class HiveQuerySuite extends HiveComparisonTest { .collect() .map(x => Pair(x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(EQ(results).===(Array(Pair("foo", 4)))) TestHive.reset() } @@ -557,8 +564,8 @@ class HiveQuerySuite extends HiveComparisonTest { sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => - assert(map.size === 1) - assert(map.head === (key, value)) + assert(EQ(map.size).===(1)) + assert(EQ(map.head).===((key, value))) } } @@ -654,7 +661,7 @@ class HiveQuerySuite extends HiveComparisonTest { sql("CREATE TABLE dp_verify(intcol INT)") sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + assert(EQ(sql("SELECT * FROM dp_verify").collect()).===(Array(Row(value)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 6f57fe8958387..be1fbca69708f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -29,6 +29,13 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +/* + * Note: the DSL conversions collide with the FunSuite === operator! 
+ * We can apply the Funsuite conversion explicitly: + * assert(X === true) --> assert(convertToEqualizer(X).===(true)) + * (This file already imports convertToEqualizer) + */ + case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -80,7 +87,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Simple column projection + filter on Parquet table") { val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() - assert(rdd.size === 5, "Filter returned incorrect number of rows") + assert(convertToEqualizer(rdd.size).===(5), "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } @@ -102,7 +109,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() val rddCopy = sql("SELECT * FROM ptable").collect() val rddOrig = sql("SELECT * FROM testsource").collect() - assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") + assert(convertToEqualizer(rddCopy.size).===(rddOrig.size), "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } @@ -111,7 +118,8 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft (rddOne, rddTwo).zipped.foreach { (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { case ((value_1, value_2), index) => - assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") + assert(convertToEqualizer(value_1).===(value_2), + s"table $tableName row $counter field ${fieldNames(index)} don't match") } counter = counter + 1 } From c304b1620d85c8eb66a608ecb40e9db3992533d0 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 11:30:12 -0600 Subject: [PATCH 72/79] Removed Date and Timestamp from NativeTypes as this would force changes in the code generator, and this PR is already too big. Now passes the standard test suite (except SparkSubmitSuite, which I think is unrelated to these changes.) --- .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2 +- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 5 +---- .../src/main/scala/org/apache/spark/sql/package.scala | 8 ++++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4b05ac62b9dc6..84fb594da372d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -153,7 +153,7 @@ package object dsl { * where a literal is being combined with a symbol. Without these an * expression such as 0 < 'x is not recognized. 
*/ - case class LhsLiteral(x: Any) { + class LhsLiteral(x: Any) { val literal = Literal(x) def + (other: Symbol) = Add(literal, other) def - (other: Symbol) = Subtract(literal, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 1d4c4783154bc..5dd19dd12d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -195,8 +195,7 @@ case object NullType extends DataType object NativeType { val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, - ByteType, StringType, DateType, TimestampType) + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -208,8 +207,6 @@ object NativeType { FloatType -> 4, ShortType -> 2, ByteType -> 1, - DateType -> 8, - TimestampType -> 12, StringType -> 4096) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 30b0c77d4c461..a1b3a221978f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -469,13 +469,13 @@ package object sql { * are provided. The class intializer accepts a String, e.g. * * {{{ - * val d = Date("2014-01-01") + * val d = RichDate("2014-01-01") * }}} * * @group dataType */ @DeveloperApi - val Date = catalyst.expressions.RichDate + val RichDate = catalyst.expressions.RichDate /** * :: DeveloperApi :: @@ -486,11 +486,11 @@ package object sql { * String, e.g. * * {{{ - * val ts = Timestamp("2014-01-01 12:34:56.78") + * val ts = RichTimestamp("2014-01-01 12:34:56.78") * }}} * * @group timeClasses */ @DeveloperApi - val Timestamp = catalyst.expressions.RichTimestamp + val RichTimestamp = catalyst.expressions.RichTimestamp } From 32df4742703b70c84e722452f6604e729f3a2a3b Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Wed, 5 Nov 2014 13:41:05 -0600 Subject: [PATCH 73/79] Added tests for the features in this PR. Added Date and Timestamp as aliases for RichDate and RichTimestamp when importing an SQLContext. 
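A minimal usage sketch of what the two patches above add, assuming the catalyst DSL implicits and the SQLContext aliases described in these diffs are in scope; the attribute names ('x, 'date) and the commented-out sqlContext import are illustrative assumptions, not an excerpt from the patch:

{{{
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{RichDate, RichTimestamp}

// With the LhsLiteral conversion, a bare literal can now start a DSL
// expression; this builds the same tree as LessThan(0, 'x) in the tests below.
val p1 = 0 < 'x

// RichDate/RichTimestamp take a String in their initializer and can be
// compared with attributes and with each other.
val p2 = RichDate("2014-11-05") > 'date
val earlier = RichDate("2014-11-05") < RichDate("2014-11-06")   // true

// After `import sqlContext._`, the same classes are available under the
// shorter aliases added in SQLContext.scala (assumes a SQLContext instance):
// val cutoff = Date("2014-01-01")
}}}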
---
 .../ExpressionEvaluationSuite.scala | 21 +++++++++++++++++++++
 .../org/apache/spark/sql/SQLContext.scala | 6 ++++++
 2 files changed, 27 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 128a4860843ff..36a4adf9321a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -779,4 +779,25 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(c1 ^ c2, 3, row)
     checkEvaluation(~c1, -2, row)
   }
+
+  test("recognizes literals on the left") {
+    assert(EQ(-1 + 'x).===(Add(-1, 'x)))
+    assert(EQ(0 < 'x).===(LessThan(0, 'x)))
+    assert(EQ(1.5 === 'x).===(EqualTo(1.5, 'x)))
+    assert(EQ(false !== 'x).===(Not(EqualTo(false, 'x))))
+    assert(EQ("a string" >= 'x).===(GreaterThanOrEqual("a string", 'x)))
+    assert(EQ(RichDate("2014-11-05") > 'date).===(GreaterThan(RichDate("2014-11-05"), 'date)))
+    assert(EQ(RichTimestamp("2014-11-05 12:34:56.789") < 'now).===(
+      LessThan(RichTimestamp("2014-11-05 12:34:56.789"), 'now)))
+  }
+
+  test("comparison operators for RichDate and RichTimestamp") {
+    assert(EQ(RichDate("2014-11-05") < RichDate("2014-11-06")).===(true))
+    assert(EQ(RichDate("2014-11-05") <= RichDate("2013-11-06")).===(false))
+    assert(EQ(RichTimestamp("2014-11-05 12:34:56.5432") > RichTimestamp("2014-11-05 00:00:00")
+      ).===(true))
+    assert(EQ(RichTimestamp("2014-11-05 12:34:56") >= RichTimestamp("2014-11-06 00:00:00")
+      ).===(false))
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 84eaf401f240c..98370dcf81713 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -502,4 +502,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
     new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
   }
+
+  /**
+   * Map RichDate and RichTimestamp to their expected names in this context.
+   */
+  val Date = org.apache.spark.sql.catalyst.expressions.RichDate
+  val Timestamp = org.apache.spark.sql.catalyst.expressions.RichTimestamp
 }

From d52e6d716a17b408f81e6f1b78f4eb79455ae393 Mon Sep 17 00:00:00 2001
From: Marc Culler
Date: Thu, 6 Nov 2014 06:22:49 -0600
Subject: [PATCH 74/79] Removed an accidentally added extraneous import from Row.scala.

---
 .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 99b9e6efbab90..c849f60bafe20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.types.NativeType
 import java.sql.{Date, Timestamp}
-import java.math.BigDecimal
 object Row {
   /**

From 45f9478bf243ddffc0949e16e11ab60394fc2d0b Mon Sep 17 00:00:00 2001
From: Marc Culler
Date: Thu, 6 Nov 2014 06:25:47 -0600
Subject: [PATCH 75/79] ... and removed another extraneous import.
--- .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index c849f60bafe20..fab0cd31407f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType -import java.sql.{Date, Timestamp} + object Row { /** From f1260427d30fafd0edfc2d232f7729e6bbe1ff80 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 06:27:36 -0600 Subject: [PATCH 76/79] Tiny style issue. --- .../scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index fab0cd31407f1..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType - object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: From 2ec6f6bc1fbedcce81d83de20b481382dac18a4c Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 07:04:53 -0600 Subject: [PATCH 77/79] Cleaning up comments. --- .../catalyst/expressions/ExpressionEvaluationSuite.scala | 6 +++--- .../org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 6 +++--- .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/SQLConfSuite.scala | 6 +++--- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 6 +++--- .../spark/sql/columnar/PartitionBatchPruningSuite.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 6 +++--- .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 6 +++--- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 6 +++--- .../org/apache/spark/sql/parquet/HiveParquetSuite.scala | 6 +++--- 13 files changed, 39 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 36a4adf9321a1..472e97cad79b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 82887dc9d4604..3586e6557aa14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NullType} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index d0fec34fd2cf3..bc8efd787b1dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 687c5a2707587..34d58a09af85c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0260eea467d1c..3d1f1801778c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 4a09ed517d6e1..df22ec9b1bcd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.test._ import TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 95582eec06975..5d5c2b0b5168e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.test._ import TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 45086f88a27ff..a6f31e0e15302 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0ed380c63527a..2c43c110e32b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! 
+ * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 25031910c30de..0f53e9736fb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f3323952b75d5..3565b377ef602 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b3ea6549fab8c..3e9805dfe8067 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} /* - * Note: the DSL conversions collide with the FunSuite === operator! - * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(EQ(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) */ import org.scalatest.Assertions.{convertToEqualizer => EQ} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index be1fbca69708f..acefbe66dbd91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.util.Utils import org.apache.spark.sql.hive.test.TestHive._ /* - * Note: the DSL conversions collide with the FunSuite === operator! 
- * We can apply the Funsuite conversion explicitly: - * assert(X === true) --> assert(convertToEqualizer(X).===(true)) + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) * (This file already imports convertToEqualizer) */ From 3a14915442eed860f58130b1ceba659fbfd03269 Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Thu, 6 Nov 2014 10:02:41 -0600 Subject: [PATCH 78/79] One last comment clarification. --- .../scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index acefbe66dbd91..0069726426f96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.test.TestHive._ /* * Note: the DSL conversions collide with the scalatest === operator! * We can apply the scalatest conversion explicitly: - * assert(X === Y) --> assert(EQ(X).===(Y)) + * assert(X === Y) --> assert(convertToEqualizer(X).===(Y)) * (This file already imports convertToEqualizer) */ From d0a27abc164d256f25eb58710af9ce1e07b4b03f Mon Sep 17 00:00:00 2001 From: Marc Culler Date: Fri, 7 Nov 2014 07:36:39 -0600 Subject: [PATCH 79/79] One more test to fix after rebasing. --- .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 1806a1dd82023..3c1505fd63e43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,13 @@ import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ +/* + * Note: the DSL conversions collide with the scalatest === operator! + * We can apply the scalatest conversion explicitly: + * assert(X === Y) --> assert(EQ(X).===(Y)) + */ +import org.scalatest.Assertions.{convertToEqualizer => EQ} + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -69,14 +76,14 @@ class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() - assert(labelsArrays.size === 2) + assert(EQ(labelsArrays.size).===(2)) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() - assert(featuresArrays.size === 2) + assert(EQ(featuresArrays.size).===(2)) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) }
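Since the `EQ` rewrite recurs in every suite touched above, here is a compact sketch of the pattern in isolation, under the assumption that the catalyst DSL implicits are imported alongside ScalaTest; the suite name and values are invented for illustration:

{{{
import org.scalatest.FunSuite
// Rename ScalaTest's implicit conversion so it can be applied explicitly.
import org.scalatest.Assertions.{convertToEqualizer => EQ}
import org.apache.spark.sql.catalyst.dsl.expressions._

class EqualizerWorkaroundExample extends FunSuite {
  test("explicit equalizer avoids the DSL === collision") {
    val xs = Seq(1, 2, 3)
    // With the DSL implicits in scope, `xs.size === 3` no longer resolves
    // cleanly to ScalaTest's ===, so the conversion is applied by hand:
    assert(EQ(xs.size).===(3))
  }
}
}}}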